Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/_src/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/__pycache__/magic_trace.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/delayed_mul_tensor.py +77 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/magic_trace.py +42 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/wrap_type.py +71 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__init__.py +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/_parsing.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/rearrange.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/rearrange.py +207 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/__pycache__/extrapolation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/__pycache__/quadrature.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/approximation.py +246 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/__init__.py +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/__pycache__/calculus.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/calculus.py +531 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/eigen.py +877 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/linalg.py +790 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/matrices.py +1005 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h +1853 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h +273 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/include/cupti_pcsampling.h +923 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_infer_v8.h +571 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version.h +70 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/cudalibxt.h +97 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/cufftXt.h +269 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/include/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia_nvtx_cu11-11.8.86.dist-info/METADATA +35 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia_nvtx_cu11-11.8.86.dist-info/WHEEL +5 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/__pycache__/_elffile.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/_parser.py +354 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/markers.py +331 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/metadata.py +863 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/py.typed +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/specifiers.py +1020 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/tags.py +617 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/version.py +582 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/cachecontrol/__init__.py +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/cachecontrol/__pycache__/__init__.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/_src/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (215 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/__pycache__/magic_trace.cpython-311.pyc
ADDED
|
Binary file (2.48 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/delayed_mul_tensor.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import _Tensor, Tensor
|
| 9 |
+
from .reference import _dims, _enable_layers, llist, ltuple
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DelayedMulTensor(_Tensor):
|
| 13 |
+
def __init__(self, lhs, rhs):
|
| 14 |
+
self._lhs, self._rhs = lhs, rhs
|
| 15 |
+
self._data = None
|
| 16 |
+
self._levels_data = None
|
| 17 |
+
self._has_device = lhs._has_device or rhs._has_device
|
| 18 |
+
self._batchtensor_data = None
|
| 19 |
+
self._tensor_data = None
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def _levels(self):
|
| 23 |
+
if self._levels_data is None:
|
| 24 |
+
levels = llist(self._lhs._levels)
|
| 25 |
+
for l in self._rhs._levels:
|
| 26 |
+
if l not in levels:
|
| 27 |
+
levels.append(l)
|
| 28 |
+
self._levels_data = ltuple(levels)
|
| 29 |
+
return self._levels_data
|
| 30 |
+
|
| 31 |
+
@property
|
| 32 |
+
def _batchtensor(self):
|
| 33 |
+
if self._batchtensor_data is None:
|
| 34 |
+
with _enable_layers(self._levels):
|
| 35 |
+
print("bt multiply fallback")
|
| 36 |
+
self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
|
| 37 |
+
return self._batchtensor_data
|
| 38 |
+
|
| 39 |
+
@property
|
| 40 |
+
def _tensor(self):
|
| 41 |
+
if self._tensor_data is None:
|
| 42 |
+
self._tensor_data = Tensor.from_batched(
|
| 43 |
+
self._batchtensor, self._has_device
|
| 44 |
+
)._tensor
|
| 45 |
+
return self._tensor_data
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def ndim(self):
|
| 49 |
+
return self._batchtensor.ndim
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def dims(self):
|
| 53 |
+
return ltuple(super().dims)
|
| 54 |
+
|
| 55 |
+
def sum(self, dim):
|
| 56 |
+
dims = _dims(dim, 0, False, False)
|
| 57 |
+
n = ord("a")
|
| 58 |
+
all_levels = self._levels
|
| 59 |
+
|
| 60 |
+
def to_char(d):
|
| 61 |
+
return chr(n + all_levels.index(d))
|
| 62 |
+
|
| 63 |
+
plhs, levelslhs = self._lhs._tensor, self._lhs._levels
|
| 64 |
+
prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
|
| 65 |
+
new_dims = tuple(d for d in self.dims if d not in dims)
|
| 66 |
+
new_levels = [l for l in self._levels if l not in dims]
|
| 67 |
+
fmt = "".join(
|
| 68 |
+
[
|
| 69 |
+
*(to_char(d) for d in levelslhs),
|
| 70 |
+
",",
|
| 71 |
+
*(to_char(d) for d in levelsrhs),
|
| 72 |
+
"->",
|
| 73 |
+
*(to_char(d) for d in new_levels),
|
| 74 |
+
]
|
| 75 |
+
)
|
| 76 |
+
result_data = torch.einsum(fmt, (plhs, prhs))
|
| 77 |
+
return Tensor.from_positional(result_data, new_levels, True)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/magic_trace.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import os
|
| 7 |
+
import signal
|
| 8 |
+
import subprocess
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@contextmanager
|
| 13 |
+
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
|
| 14 |
+
pid = os.getpid()
|
| 15 |
+
if not os.path.exists(magic_trace_cache):
|
| 16 |
+
print(f"Downloading magic_trace to: {magic_trace_cache}")
|
| 17 |
+
subprocess.run(
|
| 18 |
+
[
|
| 19 |
+
"wget",
|
| 20 |
+
"-O",
|
| 21 |
+
magic_trace_cache,
|
| 22 |
+
"-q",
|
| 23 |
+
"https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
|
| 24 |
+
]
|
| 25 |
+
)
|
| 26 |
+
subprocess.run(["chmod", "+x", magic_trace_cache])
|
| 27 |
+
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
|
| 28 |
+
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
|
| 29 |
+
while True:
|
| 30 |
+
x = p.stderr.readline()
|
| 31 |
+
print(x)
|
| 32 |
+
if "Attached" in x:
|
| 33 |
+
break
|
| 34 |
+
try:
|
| 35 |
+
yield
|
| 36 |
+
finally:
|
| 37 |
+
p.send_signal(signal.SIGINT)
|
| 38 |
+
r = p.wait()
|
| 39 |
+
print(p.stderr.read())
|
| 40 |
+
p.stderr.close()
|
| 41 |
+
if r != 0:
|
| 42 |
+
raise ValueError(f"magic_trace exited abnormally: {r}")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/dim/wrap_type.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from types import (
|
| 8 |
+
BuiltinMethodType,
|
| 9 |
+
FunctionType,
|
| 10 |
+
GetSetDescriptorType,
|
| 11 |
+
MethodDescriptorType,
|
| 12 |
+
WrapperDescriptorType,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from functorch._C import dim as _C
|
| 16 |
+
|
| 17 |
+
_wrap_method = _C._wrap_method
|
| 18 |
+
|
| 19 |
+
FUNC_TYPES = (
|
| 20 |
+
FunctionType,
|
| 21 |
+
MethodDescriptorType,
|
| 22 |
+
BuiltinMethodType,
|
| 23 |
+
WrapperDescriptorType,
|
| 24 |
+
)
|
| 25 |
+
PROPERTY_TYPES = (GetSetDescriptorType, property)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _py_wrap_method(orig, __torch_function__):
|
| 29 |
+
def impl(*args, **kwargs):
|
| 30 |
+
return __torch_function__(orig, None, args, kwargs)
|
| 31 |
+
|
| 32 |
+
return impl
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def wrap_type(use_c, to_patch, pattern, __torch_function__):
|
| 36 |
+
if use_c:
|
| 37 |
+
wrap_method = _wrap_method
|
| 38 |
+
else:
|
| 39 |
+
wrap_method = _py_wrap_method
|
| 40 |
+
|
| 41 |
+
all = {}
|
| 42 |
+
for t in reversed(pattern.mro()[:-1]): # skip object
|
| 43 |
+
all.update(t.__dict__)
|
| 44 |
+
|
| 45 |
+
def wrap_attr(orig):
|
| 46 |
+
return property(wrap_method(orig.__get__, __torch_function__))
|
| 47 |
+
|
| 48 |
+
for name, obj in all.items():
|
| 49 |
+
if name in (
|
| 50 |
+
"__dict__",
|
| 51 |
+
"__new__",
|
| 52 |
+
"__init__",
|
| 53 |
+
"__repr__",
|
| 54 |
+
"__weakref__",
|
| 55 |
+
"__doc__",
|
| 56 |
+
"__module__",
|
| 57 |
+
"__dir__",
|
| 58 |
+
):
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
# skip things that have been overloaded
|
| 62 |
+
# things that come from object like `__eq__` still need to be patched, however.
|
| 63 |
+
if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
|
| 64 |
+
object, name, None
|
| 65 |
+
):
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
if isinstance(obj, FUNC_TYPES):
|
| 69 |
+
setattr(to_patch, name, wrap_method(obj, __torch_function__))
|
| 70 |
+
elif isinstance(obj, PROPERTY_TYPES):
|
| 71 |
+
setattr(to_patch, name, wrap_attr(obj))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .rearrange import rearrange
|
| 2 |
+
|
| 3 |
+
__all__ = ["rearrange"]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (288 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/_parsing.cpython-311.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/__pycache__/rearrange.cpython-311.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/functorch/einops/rearrange.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
from typing import Callable, Dict, List, Sequence, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from functorch._C import dim as _C
|
| 9 |
+
from ._parsing import (
|
| 10 |
+
_ellipsis,
|
| 11 |
+
AnonymousAxis,
|
| 12 |
+
comma_separate,
|
| 13 |
+
parse_pattern,
|
| 14 |
+
validate_rearrange_expressions,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = ["rearrange"]
|
| 18 |
+
|
| 19 |
+
dims = _C.dims
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@functools.lru_cache(256)
|
| 23 |
+
def _create_rearrange_callable(
|
| 24 |
+
tensor_ndim: int, pattern: str, **axes_lengths: int
|
| 25 |
+
) -> Callable[[torch.Tensor], torch.Tensor]:
|
| 26 |
+
r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.
|
| 27 |
+
|
| 28 |
+
Since the an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
|
| 29 |
+
specified axes lengths, this function can be memoized.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
tensor_ndim (int): the number of dimensions in the tensor to rearrange
|
| 33 |
+
pattern (str): the `einops`-style rearrangement pattern
|
| 34 |
+
axes_lengths (int): any additional length specifications for dimensions
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
|
| 38 |
+
"""
|
| 39 |
+
left, right = parse_pattern(pattern, axes_lengths)
|
| 40 |
+
validate_rearrange_expressions(left, right, axes_lengths)
|
| 41 |
+
|
| 42 |
+
n_anon_dims = sum(not dim for dim in left.composition)
|
| 43 |
+
if left.has_ellipsis:
|
| 44 |
+
n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
|
| 45 |
+
n_named_dims = len(left.identifiers) - 1
|
| 46 |
+
|
| 47 |
+
if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
|
| 48 |
+
raise ValueError(
|
| 49 |
+
f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
|
| 50 |
+
f"dimensions in the tensor ({tensor_ndim})"
|
| 51 |
+
)
|
| 52 |
+
else:
|
| 53 |
+
n_ellipsis_dims = 0
|
| 54 |
+
n_named_dims = len(left.identifiers)
|
| 55 |
+
|
| 56 |
+
if (pattern_ndim := len(left.composition)) != tensor_ndim:
|
| 57 |
+
raise ValueError(
|
| 58 |
+
f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
|
| 59 |
+
f"the tensor ({tensor_ndim})"
|
| 60 |
+
)
|
| 61 |
+
n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims
|
| 62 |
+
|
| 63 |
+
if n_dims == 0:
|
| 64 |
+
# an identity rearrangement on a 0-dimension tensor
|
| 65 |
+
return lambda tensor: tensor
|
| 66 |
+
|
| 67 |
+
first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
|
| 68 |
+
identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
|
| 69 |
+
anon_axes: List[AnonymousAxis] = []
|
| 70 |
+
|
| 71 |
+
# map the left-hand side identifiers to strings representing first class dims
|
| 72 |
+
dims_i = 0
|
| 73 |
+
for dimension in left.composition:
|
| 74 |
+
if isinstance(dimension, list):
|
| 75 |
+
for identifier in dimension:
|
| 76 |
+
# non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
|
| 77 |
+
assert isinstance(identifier, str)
|
| 78 |
+
identifier_dim_map[identifier] = (first_class_dims[dims_i],)
|
| 79 |
+
dims_i += 1
|
| 80 |
+
if not dimension:
|
| 81 |
+
# unitary anonymous axis
|
| 82 |
+
anon_axis = AnonymousAxis("1")
|
| 83 |
+
identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
|
| 84 |
+
anon_axes.append(anon_axis)
|
| 85 |
+
dimension.append(anon_axis)
|
| 86 |
+
dims_i += 1
|
| 87 |
+
elif dimension == _ellipsis:
|
| 88 |
+
identifier = _ellipsis
|
| 89 |
+
identifier_dim_map[identifier] = tuple(
|
| 90 |
+
first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
|
| 91 |
+
)
|
| 92 |
+
dims_i += n_ellipsis_dims
|
| 93 |
+
else:
|
| 94 |
+
raise ValueError(f"Unexpected dimension: {dimension}")
|
| 95 |
+
|
| 96 |
+
def composition_to_dims(
|
| 97 |
+
composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
|
| 98 |
+
) -> List[Union[str, Tuple[str, ...]]]:
|
| 99 |
+
"""Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
|
| 100 |
+
class dims."""
|
| 101 |
+
dim_composition: List[Union[str, Tuple[str, ...]]] = []
|
| 102 |
+
for dimension in composition:
|
| 103 |
+
if isinstance(dimension, list):
|
| 104 |
+
dim_composition.append(
|
| 105 |
+
tuple(
|
| 106 |
+
dim
|
| 107 |
+
for identifier in dimension
|
| 108 |
+
for dim in identifier_dim_map[identifier]
|
| 109 |
+
)
|
| 110 |
+
)
|
| 111 |
+
elif dimension == _ellipsis:
|
| 112 |
+
dim_composition.extend(identifier_dim_map[_ellipsis])
|
| 113 |
+
else:
|
| 114 |
+
raise ValueError(f"Unexpected dimension: {dimension}")
|
| 115 |
+
return dim_composition
|
| 116 |
+
|
| 117 |
+
left_dims = composition_to_dims(left.composition)
|
| 118 |
+
right_dims = composition_to_dims(right.composition)
|
| 119 |
+
anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
|
| 120 |
+
specified_lengths = tuple(
|
| 121 |
+
(identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
custom_rearrange_callable_name = "do_rearrange"
|
| 125 |
+
custom_rearrange_callable_code = (
|
| 126 |
+
(
|
| 127 |
+
f"def {custom_rearrange_callable_name}(tensor):\n"
|
| 128 |
+
f" {comma_separate(first_class_dims)} = dims({n_dims})\n"
|
| 129 |
+
)
|
| 130 |
+
+ (
|
| 131 |
+
"".join(
|
| 132 |
+
f" {dim}.size = {length}\n" for (dim, length) in specified_lengths
|
| 133 |
+
)
|
| 134 |
+
if specified_lengths
|
| 135 |
+
else ""
|
| 136 |
+
)
|
| 137 |
+
+ f" tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
|
| 138 |
+
+ (
|
| 139 |
+
f" return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
|
| 140 |
+
if anon_dims
|
| 141 |
+
else " return tensor\n"
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
exec(custom_rearrange_callable_code)
|
| 146 |
+
return locals()[custom_rearrange_callable_name]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def rearrange(
|
| 150 |
+
tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
|
| 151 |
+
pattern: str,
|
| 152 |
+
**axes_lengths: int,
|
| 153 |
+
) -> torch.Tensor:
|
| 154 |
+
r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
|
| 155 |
+
tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
|
| 156 |
+
stack, concatenate and other operations.
|
| 157 |
+
|
| 158 |
+
See: https://einops.rocks/api/rearrange/
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
|
| 162 |
+
pattern (str): the rearrangement pattern
|
| 163 |
+
axes_lengths (int): any additional length specifications for dimensions
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Tensor: the rearranged tensor
|
| 167 |
+
|
| 168 |
+
Examples:
|
| 169 |
+
>>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
|
| 170 |
+
>>> images = torch.randn((32, 30, 40, 3))
|
| 171 |
+
|
| 172 |
+
>>> # stack along first (batch) axis, output is a single array
|
| 173 |
+
>>> rearrange(images, 'b h w c -> b h w c').shape
|
| 174 |
+
torch.Size([32, 30, 40, 3])
|
| 175 |
+
|
| 176 |
+
>>> # concatenate images along height (vertical axis), 960 = 32 * 30
|
| 177 |
+
>>> rearrange(images, 'b h w c -> (b h) w c').shape
|
| 178 |
+
torch.Size([960, 40, 3])
|
| 179 |
+
|
| 180 |
+
>>> # concatenated images along horizontal axis, 1280 = 32 * 40
|
| 181 |
+
>>> rearrange(images, 'b h w c -> h (b w) c').shape
|
| 182 |
+
torch.Size([30, 1280, 3])
|
| 183 |
+
|
| 184 |
+
>>> # reordered axes to "b c h w" format for deep learning
|
| 185 |
+
>>> rearrange(images, 'b h w c -> b c h w').shape
|
| 186 |
+
torch.Size([32, 3, 30, 40])
|
| 187 |
+
|
| 188 |
+
>>> # flattened each image into a vector, 3600 = 30 * 40 * 3
|
| 189 |
+
>>> rearrange(images, 'b h w c -> b (c h w)').shape
|
| 190 |
+
torch.Size([32, 3600])
|
| 191 |
+
|
| 192 |
+
>>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
|
| 193 |
+
>>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
|
| 194 |
+
torch.Size([128, 15, 20, 3])
|
| 195 |
+
|
| 196 |
+
>>> # space-to-depth operation
|
| 197 |
+
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
|
| 198 |
+
torch.Size([32, 15, 20, 12])
|
| 199 |
+
"""
|
| 200 |
+
if not isinstance(tensor, torch.Tensor):
|
| 201 |
+
tensor = torch.stack(tensor)
|
| 202 |
+
|
| 203 |
+
rearrange_callable = _create_rearrange_callable(
|
| 204 |
+
tensor.ndim, pattern, **axes_lengths
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
return rearrange_callable(tensor)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/__pycache__/extrapolation.cpython-311.pyc
ADDED
|
Binary file (89.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/__pycache__/quadrature.cpython-311.pyc
ADDED
|
Binary file (50.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/calculus/approximation.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..libmp.backend import xrange
|
| 2 |
+
from .calculus import defun
|
| 3 |
+
|
| 4 |
+
#----------------------------------------------------------------------------#
|
| 5 |
+
# Approximation methods #
|
| 6 |
+
#----------------------------------------------------------------------------#
|
| 7 |
+
|
| 8 |
+
# The Chebyshev approximation formula is given at:
|
| 9 |
+
# http://mathworld.wolfram.com/ChebyshevApproximationFormula.html
|
| 10 |
+
|
| 11 |
+
# The only major changes in the following code is that we return the
|
| 12 |
+
# expanded polynomial coefficients instead of Chebyshev coefficients,
|
| 13 |
+
# and that we automatically transform [a,b] -> [-1,1] and back
|
| 14 |
+
# for convenience.
|
| 15 |
+
|
| 16 |
+
# Coefficient in Chebyshev approximation
|
| 17 |
+
def chebcoeff(ctx,f,a,b,j,N):
|
| 18 |
+
s = ctx.mpf(0)
|
| 19 |
+
h = ctx.mpf(0.5)
|
| 20 |
+
for k in range(1, N+1):
|
| 21 |
+
t = ctx.cospi((k-h)/N)
|
| 22 |
+
s += f(t*(b-a)*h + (b+a)*h) * ctx.cospi(j*(k-h)/N)
|
| 23 |
+
return 2*s/N
|
| 24 |
+
|
| 25 |
+
# Generate Chebyshev polynomials T_n(ax+b) in expanded form
|
| 26 |
+
def chebT(ctx, a=1, b=0):
|
| 27 |
+
Tb = [1]
|
| 28 |
+
yield Tb
|
| 29 |
+
Ta = [b, a]
|
| 30 |
+
while 1:
|
| 31 |
+
yield Ta
|
| 32 |
+
# Recurrence: T[n+1](ax+b) = 2*(ax+b)*T[n](ax+b) - T[n-1](ax+b)
|
| 33 |
+
Tmp = [0] + [2*a*t for t in Ta]
|
| 34 |
+
for i, c in enumerate(Ta): Tmp[i] += 2*b*c
|
| 35 |
+
for i, c in enumerate(Tb): Tmp[i] -= c
|
| 36 |
+
Ta, Tb = Tmp, Ta
|
| 37 |
+
|
| 38 |
+
@defun
|
| 39 |
+
def chebyfit(ctx, f, interval, N, error=False):
|
| 40 |
+
r"""
|
| 41 |
+
Computes a polynomial of degree `N-1` that approximates the
|
| 42 |
+
given function `f` on the interval `[a, b]`. With ``error=True``,
|
| 43 |
+
:func:`~mpmath.chebyfit` also returns an accurate estimate of the
|
| 44 |
+
maximum absolute error; that is, the maximum value of
|
| 45 |
+
`|f(x) - P(x)|` for `x \in [a, b]`.
|
| 46 |
+
|
| 47 |
+
:func:`~mpmath.chebyfit` uses the Chebyshev approximation formula,
|
| 48 |
+
which gives a nearly optimal solution: that is, the maximum
|
| 49 |
+
error of the approximating polynomial is very close to
|
| 50 |
+
the smallest possible for any polynomial of the same degree.
|
| 51 |
+
|
| 52 |
+
Chebyshev approximation is very useful if one needs repeated
|
| 53 |
+
evaluation of an expensive function, such as function defined
|
| 54 |
+
implicitly by an integral or a differential equation. (For
|
| 55 |
+
example, it could be used to turn a slow mpmath function
|
| 56 |
+
into a fast machine-precision version of the same.)
|
| 57 |
+
|
| 58 |
+
**Examples**
|
| 59 |
+
|
| 60 |
+
Here we use :func:`~mpmath.chebyfit` to generate a low-degree approximation
|
| 61 |
+
of `f(x) = \cos(x)`, valid on the interval `[1, 2]`::
|
| 62 |
+
|
| 63 |
+
>>> from mpmath import *
|
| 64 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 65 |
+
>>> poly, err = chebyfit(cos, [1, 2], 5, error=True)
|
| 66 |
+
>>> nprint(poly)
|
| 67 |
+
[0.00291682, 0.146166, -0.732491, 0.174141, 0.949553]
|
| 68 |
+
>>> nprint(err, 12)
|
| 69 |
+
1.61351758081e-5
|
| 70 |
+
|
| 71 |
+
The polynomial can be evaluated using ``polyval``::
|
| 72 |
+
|
| 73 |
+
>>> nprint(polyval(poly, 1.6), 12)
|
| 74 |
+
-0.0291858904138
|
| 75 |
+
>>> nprint(cos(1.6), 12)
|
| 76 |
+
-0.0291995223013
|
| 77 |
+
|
| 78 |
+
Sampling the true error at 1000 points shows that the error
|
| 79 |
+
estimate generated by ``chebyfit`` is remarkably good::
|
| 80 |
+
|
| 81 |
+
>>> error = lambda x: abs(cos(x) - polyval(poly, x))
|
| 82 |
+
>>> nprint(max([error(1+n/1000.) for n in range(1000)]), 12)
|
| 83 |
+
1.61349954245e-5
|
| 84 |
+
|
| 85 |
+
**Choice of degree**
|
| 86 |
+
|
| 87 |
+
The degree `N` can be set arbitrarily high, to obtain an
|
| 88 |
+
arbitrarily good approximation. As a rule of thumb, an
|
| 89 |
+
`N`-term Chebyshev approximation is good to `N/(b-a)` decimal
|
| 90 |
+
places on a unit interval (although this depends on how
|
| 91 |
+
well-behaved `f` is). The cost grows accordingly: ``chebyfit``
|
| 92 |
+
evaluates the function `(N^2)/2` times to compute the
|
| 93 |
+
coefficients and an additional `N` times to estimate the error.
|
| 94 |
+
|
| 95 |
+
**Possible issues**
|
| 96 |
+
|
| 97 |
+
One should be careful to use a sufficiently high working
|
| 98 |
+
precision both when calling ``chebyfit`` and when evaluating
|
| 99 |
+
the resulting polynomial, as the polynomial is sometimes
|
| 100 |
+
ill-conditioned. It is for example difficult to reach
|
| 101 |
+
15-digit accuracy when evaluating the polynomial using
|
| 102 |
+
machine precision floats, no matter the theoretical
|
| 103 |
+
accuracy of the polynomial. (The option to return the
|
| 104 |
+
coefficients in Chebyshev form should be made available
|
| 105 |
+
in the future.)
|
| 106 |
+
|
| 107 |
+
It is important to note the Chebyshev approximation works
|
| 108 |
+
poorly if `f` is not smooth. A function containing singularities,
|
| 109 |
+
rapid oscillation, etc can be approximated more effectively by
|
| 110 |
+
multiplying it by a weight function that cancels out the
|
| 111 |
+
nonsmooth features, or by dividing the interval into several
|
| 112 |
+
segments.
|
| 113 |
+
"""
|
| 114 |
+
a, b = ctx._as_points(interval)
|
| 115 |
+
orig = ctx.prec
|
| 116 |
+
try:
|
| 117 |
+
ctx.prec = orig + int(N**0.5) + 20
|
| 118 |
+
c = [chebcoeff(ctx,f,a,b,k,N) for k in range(N)]
|
| 119 |
+
d = [ctx.zero] * N
|
| 120 |
+
d[0] = -c[0]/2
|
| 121 |
+
h = ctx.mpf(0.5)
|
| 122 |
+
T = chebT(ctx, ctx.mpf(2)/(b-a), ctx.mpf(-1)*(b+a)/(b-a))
|
| 123 |
+
for (k, Tk) in zip(range(N), T):
|
| 124 |
+
for i in range(len(Tk)):
|
| 125 |
+
d[i] += c[k]*Tk[i]
|
| 126 |
+
d = d[::-1]
|
| 127 |
+
# Estimate maximum error
|
| 128 |
+
err = ctx.zero
|
| 129 |
+
for k in range(N):
|
| 130 |
+
x = ctx.cos(ctx.pi*k/N) * (b-a)*h + (b+a)*h
|
| 131 |
+
err = max(err, abs(f(x) - ctx.polyval(d, x)))
|
| 132 |
+
finally:
|
| 133 |
+
ctx.prec = orig
|
| 134 |
+
if error:
|
| 135 |
+
return d, +err
|
| 136 |
+
else:
|
| 137 |
+
return d
|
| 138 |
+
|
| 139 |
+
@defun
|
| 140 |
+
def fourier(ctx, f, interval, N):
|
| 141 |
+
r"""
|
| 142 |
+
Computes the Fourier series of degree `N` of the given function
|
| 143 |
+
on the interval `[a, b]`. More precisely, :func:`~mpmath.fourier` returns
|
| 144 |
+
two lists `(c, s)` of coefficients (the cosine series and sine
|
| 145 |
+
series, respectively), such that
|
| 146 |
+
|
| 147 |
+
.. math ::
|
| 148 |
+
|
| 149 |
+
f(x) \sim \sum_{k=0}^N
|
| 150 |
+
c_k \cos(k m x) + s_k \sin(k m x)
|
| 151 |
+
|
| 152 |
+
where `m = 2 \pi / (b-a)`.
|
| 153 |
+
|
| 154 |
+
Note that many texts define the first coefficient as `2 c_0` instead
|
| 155 |
+
of `c_0`. The easiest way to evaluate the computed series correctly
|
| 156 |
+
is to pass it to :func:`~mpmath.fourierval`.
|
| 157 |
+
|
| 158 |
+
**Examples**
|
| 159 |
+
|
| 160 |
+
The function `f(x) = x` has a simple Fourier series on the standard
|
| 161 |
+
interval `[-\pi, \pi]`. The cosine coefficients are all zero (because
|
| 162 |
+
the function has odd symmetry), and the sine coefficients are
|
| 163 |
+
rational numbers::
|
| 164 |
+
|
| 165 |
+
>>> from mpmath import *
|
| 166 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 167 |
+
>>> c, s = fourier(lambda x: x, [-pi, pi], 5)
|
| 168 |
+
>>> nprint(c)
|
| 169 |
+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
| 170 |
+
>>> nprint(s)
|
| 171 |
+
[0.0, 2.0, -1.0, 0.666667, -0.5, 0.4]
|
| 172 |
+
|
| 173 |
+
This computes a Fourier series of a nonsymmetric function on
|
| 174 |
+
a nonstandard interval::
|
| 175 |
+
|
| 176 |
+
>>> I = [-1, 1.5]
|
| 177 |
+
>>> f = lambda x: x**2 - 4*x + 1
|
| 178 |
+
>>> cs = fourier(f, I, 4)
|
| 179 |
+
>>> nprint(cs[0])
|
| 180 |
+
[0.583333, 1.12479, -1.27552, 0.904708, -0.441296]
|
| 181 |
+
>>> nprint(cs[1])
|
| 182 |
+
[0.0, -2.6255, 0.580905, 0.219974, -0.540057]
|
| 183 |
+
|
| 184 |
+
It is instructive to plot a function along with its truncated
|
| 185 |
+
Fourier series::
|
| 186 |
+
|
| 187 |
+
>>> plot([f, lambda x: fourierval(cs, I, x)], I) #doctest: +SKIP
|
| 188 |
+
|
| 189 |
+
Fourier series generally converge slowly (and may not converge
|
| 190 |
+
pointwise). For example, if `f(x) = \cosh(x)`, a 10-term Fourier
|
| 191 |
+
series gives an `L^2` error corresponding to 2-digit accuracy::
|
| 192 |
+
|
| 193 |
+
>>> I = [-1, 1]
|
| 194 |
+
>>> cs = fourier(cosh, I, 9)
|
| 195 |
+
>>> g = lambda x: (cosh(x) - fourierval(cs, I, x))**2
|
| 196 |
+
>>> nprint(sqrt(quad(g, I)))
|
| 197 |
+
0.00467963
|
| 198 |
+
|
| 199 |
+
:func:`~mpmath.fourier` uses numerical quadrature. For nonsmooth functions,
|
| 200 |
+
the accuracy (and speed) can be improved by including all singular
|
| 201 |
+
points in the interval specification::
|
| 202 |
+
|
| 203 |
+
>>> nprint(fourier(abs, [-1, 1], 0), 10)
|
| 204 |
+
([0.5000441648], [0.0])
|
| 205 |
+
>>> nprint(fourier(abs, [-1, 0, 1], 0), 10)
|
| 206 |
+
([0.5], [0.0])
|
| 207 |
+
|
| 208 |
+
"""
|
| 209 |
+
interval = ctx._as_points(interval)
|
| 210 |
+
a = interval[0]
|
| 211 |
+
b = interval[-1]
|
| 212 |
+
L = b-a
|
| 213 |
+
cos_series = []
|
| 214 |
+
sin_series = []
|
| 215 |
+
cutoff = ctx.eps*10
|
| 216 |
+
for n in xrange(N+1):
|
| 217 |
+
m = 2*n*ctx.pi/L
|
| 218 |
+
an = 2*ctx.quadgl(lambda t: f(t)*ctx.cos(m*t), interval)/L
|
| 219 |
+
bn = 2*ctx.quadgl(lambda t: f(t)*ctx.sin(m*t), interval)/L
|
| 220 |
+
if n == 0:
|
| 221 |
+
an /= 2
|
| 222 |
+
if abs(an) < cutoff: an = ctx.zero
|
| 223 |
+
if abs(bn) < cutoff: bn = ctx.zero
|
| 224 |
+
cos_series.append(an)
|
| 225 |
+
sin_series.append(bn)
|
| 226 |
+
return cos_series, sin_series
|
| 227 |
+
|
| 228 |
+
@defun
|
| 229 |
+
def fourierval(ctx, series, interval, x):
|
| 230 |
+
"""
|
| 231 |
+
Evaluates a Fourier series (in the format computed by
|
| 232 |
+
by :func:`~mpmath.fourier` for the given interval) at the point `x`.
|
| 233 |
+
|
| 234 |
+
The series should be a pair `(c, s)` where `c` is the
|
| 235 |
+
cosine series and `s` is the sine series. The two lists
|
| 236 |
+
need not have the same length.
|
| 237 |
+
"""
|
| 238 |
+
cs, ss = series
|
| 239 |
+
ab = ctx._as_points(interval)
|
| 240 |
+
a = interval[0]
|
| 241 |
+
b = interval[-1]
|
| 242 |
+
m = 2*ctx.pi/(ab[-1]-ab[0])
|
| 243 |
+
s = ctx.zero
|
| 244 |
+
s += ctx.fsum(cs[n]*ctx.cos(m*n*x) for n in xrange(len(cs)) if cs[n])
|
| 245 |
+
s += ctx.fsum(ss[n]*ctx.sin(m*n*x) for n in xrange(len(ss)) if ss[n])
|
| 246 |
+
return s
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import eigen # to set methods
|
| 2 |
+
from . import eigen_symmetric # to set methods
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/__pycache__/calculus.cpython-311.pyc
ADDED
|
Binary file (22.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/calculus.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..libmp.backend import xrange
|
| 2 |
+
|
| 3 |
+
# TODO: should use diagonalization-based algorithms
|
| 4 |
+
|
| 5 |
+
class MatrixCalculusMethods(object):
|
| 6 |
+
|
| 7 |
+
def _exp_pade(ctx, a):
|
| 8 |
+
"""
|
| 9 |
+
Exponential of a matrix using Pade approximants.
|
| 10 |
+
|
| 11 |
+
See G. H. Golub, C. F. van Loan 'Matrix Computations',
|
| 12 |
+
third Ed., page 572
|
| 13 |
+
|
| 14 |
+
TODO:
|
| 15 |
+
- find a good estimate for q
|
| 16 |
+
- reduce the number of matrix multiplications to improve
|
| 17 |
+
performance
|
| 18 |
+
"""
|
| 19 |
+
def eps_pade(p):
|
| 20 |
+
return ctx.mpf(2)**(3-2*p) * \
|
| 21 |
+
ctx.factorial(p)**2/(ctx.factorial(2*p)**2 * (2*p + 1))
|
| 22 |
+
q = 4
|
| 23 |
+
extraq = 8
|
| 24 |
+
while 1:
|
| 25 |
+
if eps_pade(q) < ctx.eps:
|
| 26 |
+
break
|
| 27 |
+
q += 1
|
| 28 |
+
q += extraq
|
| 29 |
+
j = int(max(1, ctx.mag(ctx.mnorm(a,'inf'))))
|
| 30 |
+
extra = q
|
| 31 |
+
prec = ctx.prec
|
| 32 |
+
ctx.dps += extra + 3
|
| 33 |
+
try:
|
| 34 |
+
a = a/2**j
|
| 35 |
+
na = a.rows
|
| 36 |
+
den = ctx.eye(na)
|
| 37 |
+
num = ctx.eye(na)
|
| 38 |
+
x = ctx.eye(na)
|
| 39 |
+
c = ctx.mpf(1)
|
| 40 |
+
for k in range(1, q+1):
|
| 41 |
+
c *= ctx.mpf(q - k + 1)/((2*q - k + 1) * k)
|
| 42 |
+
x = a*x
|
| 43 |
+
cx = c*x
|
| 44 |
+
num += cx
|
| 45 |
+
den += (-1)**k * cx
|
| 46 |
+
f = ctx.lu_solve_mat(den, num)
|
| 47 |
+
for k in range(j):
|
| 48 |
+
f = f*f
|
| 49 |
+
finally:
|
| 50 |
+
ctx.prec = prec
|
| 51 |
+
return f*1
|
| 52 |
+
|
| 53 |
+
def expm(ctx, A, method='taylor'):
|
| 54 |
+
r"""
|
| 55 |
+
Computes the matrix exponential of a square matrix `A`, which is defined
|
| 56 |
+
by the power series
|
| 57 |
+
|
| 58 |
+
.. math ::
|
| 59 |
+
|
| 60 |
+
\exp(A) = I + A + \frac{A^2}{2!} + \frac{A^3}{3!} + \ldots
|
| 61 |
+
|
| 62 |
+
With method='taylor', the matrix exponential is computed
|
| 63 |
+
using the Taylor series. With method='pade', Pade approximants
|
| 64 |
+
are used instead.
|
| 65 |
+
|
| 66 |
+
**Examples**
|
| 67 |
+
|
| 68 |
+
Basic examples::
|
| 69 |
+
|
| 70 |
+
>>> from mpmath import *
|
| 71 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 72 |
+
>>> expm(zeros(3))
|
| 73 |
+
[1.0 0.0 0.0]
|
| 74 |
+
[0.0 1.0 0.0]
|
| 75 |
+
[0.0 0.0 1.0]
|
| 76 |
+
>>> expm(eye(3))
|
| 77 |
+
[2.71828182845905 0.0 0.0]
|
| 78 |
+
[ 0.0 2.71828182845905 0.0]
|
| 79 |
+
[ 0.0 0.0 2.71828182845905]
|
| 80 |
+
>>> expm([[1,1,0],[1,0,1],[0,1,0]])
|
| 81 |
+
[ 3.86814500615414 2.26812870852145 0.841130841230196]
|
| 82 |
+
[ 2.26812870852145 2.44114713886289 1.42699786729125]
|
| 83 |
+
[0.841130841230196 1.42699786729125 1.6000162976327]
|
| 84 |
+
>>> expm([[1,1,0],[1,0,1],[0,1,0]], method='pade')
|
| 85 |
+
[ 3.86814500615414 2.26812870852145 0.841130841230196]
|
| 86 |
+
[ 2.26812870852145 2.44114713886289 1.42699786729125]
|
| 87 |
+
[0.841130841230196 1.42699786729125 1.6000162976327]
|
| 88 |
+
>>> expm([[1+j, 0], [1+j,1]])
|
| 89 |
+
[(1.46869393991589 + 2.28735528717884j) 0.0]
|
| 90 |
+
[ (1.03776739863568 + 3.536943175722j) (2.71828182845905 + 0.0j)]
|
| 91 |
+
|
| 92 |
+
Matrices with large entries are allowed::
|
| 93 |
+
|
| 94 |
+
>>> expm(matrix([[1,2],[2,3]])**25)
|
| 95 |
+
[5.65024064048415e+2050488462815550 9.14228140091932e+2050488462815550]
|
| 96 |
+
[9.14228140091932e+2050488462815550 1.47925220414035e+2050488462815551]
|
| 97 |
+
|
| 98 |
+
The identity `\exp(A+B) = \exp(A) \exp(B)` does not hold for
|
| 99 |
+
noncommuting matrices::
|
| 100 |
+
|
| 101 |
+
>>> A = hilbert(3)
|
| 102 |
+
>>> B = A + eye(3)
|
| 103 |
+
>>> chop(mnorm(A*B - B*A))
|
| 104 |
+
0.0
|
| 105 |
+
>>> chop(mnorm(expm(A+B) - expm(A)*expm(B)))
|
| 106 |
+
0.0
|
| 107 |
+
>>> B = A + ones(3)
|
| 108 |
+
>>> mnorm(A*B - B*A)
|
| 109 |
+
1.8
|
| 110 |
+
>>> mnorm(expm(A+B) - expm(A)*expm(B))
|
| 111 |
+
42.0927851137247
|
| 112 |
+
|
| 113 |
+
"""
|
| 114 |
+
if method == 'pade':
|
| 115 |
+
prec = ctx.prec
|
| 116 |
+
try:
|
| 117 |
+
A = ctx.matrix(A)
|
| 118 |
+
ctx.prec += 2*A.rows
|
| 119 |
+
res = ctx._exp_pade(A)
|
| 120 |
+
finally:
|
| 121 |
+
ctx.prec = prec
|
| 122 |
+
return res
|
| 123 |
+
A = ctx.matrix(A)
|
| 124 |
+
prec = ctx.prec
|
| 125 |
+
j = int(max(1, ctx.mag(ctx.mnorm(A,'inf'))))
|
| 126 |
+
j += int(0.5*prec**0.5)
|
| 127 |
+
try:
|
| 128 |
+
ctx.prec += 10 + 2*j
|
| 129 |
+
tol = +ctx.eps
|
| 130 |
+
A = A/2**j
|
| 131 |
+
T = A
|
| 132 |
+
Y = A**0 + A
|
| 133 |
+
k = 2
|
| 134 |
+
while 1:
|
| 135 |
+
T *= A * (1/ctx.mpf(k))
|
| 136 |
+
if ctx.mnorm(T, 'inf') < tol:
|
| 137 |
+
break
|
| 138 |
+
Y += T
|
| 139 |
+
k += 1
|
| 140 |
+
for k in xrange(j):
|
| 141 |
+
Y = Y*Y
|
| 142 |
+
finally:
|
| 143 |
+
ctx.prec = prec
|
| 144 |
+
Y *= 1
|
| 145 |
+
return Y
|
| 146 |
+
|
| 147 |
+
def cosm(ctx, A):
|
| 148 |
+
r"""
|
| 149 |
+
Gives the cosine of a square matrix `A`, defined in analogy
|
| 150 |
+
with the matrix exponential.
|
| 151 |
+
|
| 152 |
+
Examples::
|
| 153 |
+
|
| 154 |
+
>>> from mpmath import *
|
| 155 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 156 |
+
>>> X = eye(3)
|
| 157 |
+
>>> cosm(X)
|
| 158 |
+
[0.54030230586814 0.0 0.0]
|
| 159 |
+
[ 0.0 0.54030230586814 0.0]
|
| 160 |
+
[ 0.0 0.0 0.54030230586814]
|
| 161 |
+
>>> X = hilbert(3)
|
| 162 |
+
>>> cosm(X)
|
| 163 |
+
[ 0.424403834569555 -0.316643413047167 -0.221474945949293]
|
| 164 |
+
[-0.316643413047167 0.820646708837824 -0.127183694770039]
|
| 165 |
+
[-0.221474945949293 -0.127183694770039 0.909236687217541]
|
| 166 |
+
>>> X = matrix([[1+j,-2],[0,-j]])
|
| 167 |
+
>>> cosm(X)
|
| 168 |
+
[(0.833730025131149 - 0.988897705762865j) (1.07485840848393 - 0.17192140544213j)]
|
| 169 |
+
[ 0.0 (1.54308063481524 + 0.0j)]
|
| 170 |
+
"""
|
| 171 |
+
B = 0.5 * (ctx.expm(A*ctx.j) + ctx.expm(A*(-ctx.j)))
|
| 172 |
+
if not sum(A.apply(ctx.im).apply(abs)):
|
| 173 |
+
B = B.apply(ctx.re)
|
| 174 |
+
return B
|
| 175 |
+
|
| 176 |
+
def sinm(ctx, A):
|
| 177 |
+
r"""
|
| 178 |
+
Gives the sine of a square matrix `A`, defined in analogy
|
| 179 |
+
with the matrix exponential.
|
| 180 |
+
|
| 181 |
+
Examples::
|
| 182 |
+
|
| 183 |
+
>>> from mpmath import *
|
| 184 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 185 |
+
>>> X = eye(3)
|
| 186 |
+
>>> sinm(X)
|
| 187 |
+
[0.841470984807897 0.0 0.0]
|
| 188 |
+
[ 0.0 0.841470984807897 0.0]
|
| 189 |
+
[ 0.0 0.0 0.841470984807897]
|
| 190 |
+
>>> X = hilbert(3)
|
| 191 |
+
>>> sinm(X)
|
| 192 |
+
[0.711608512150994 0.339783913247439 0.220742837314741]
|
| 193 |
+
[0.339783913247439 0.244113865695532 0.187231271174372]
|
| 194 |
+
[0.220742837314741 0.187231271174372 0.155816730769635]
|
| 195 |
+
>>> X = matrix([[1+j,-2],[0,-j]])
|
| 196 |
+
>>> sinm(X)
|
| 197 |
+
[(1.29845758141598 + 0.634963914784736j) (-1.96751511930922 + 0.314700021761367j)]
|
| 198 |
+
[ 0.0 (0.0 - 1.1752011936438j)]
|
| 199 |
+
"""
|
| 200 |
+
B = (-0.5j) * (ctx.expm(A*ctx.j) - ctx.expm(A*(-ctx.j)))
|
| 201 |
+
if not sum(A.apply(ctx.im).apply(abs)):
|
| 202 |
+
B = B.apply(ctx.re)
|
| 203 |
+
return B
|
| 204 |
+
|
| 205 |
+
def _sqrtm_rot(ctx, A, _may_rotate):
|
| 206 |
+
# If the iteration fails to converge, cheat by performing
|
| 207 |
+
# a rotation by a complex number
|
| 208 |
+
u = ctx.j**0.3
|
| 209 |
+
return ctx.sqrtm(u*A, _may_rotate) / ctx.sqrt(u)
|
| 210 |
+
|
| 211 |
+
def sqrtm(ctx, A, _may_rotate=2):
|
| 212 |
+
r"""
|
| 213 |
+
Computes a square root of the square matrix `A`, i.e. returns
|
| 214 |
+
a matrix `B = A^{1/2}` such that `B^2 = A`. The square root
|
| 215 |
+
of a matrix, if it exists, is not unique.
|
| 216 |
+
|
| 217 |
+
**Examples**
|
| 218 |
+
|
| 219 |
+
Square roots of some simple matrices::
|
| 220 |
+
|
| 221 |
+
>>> from mpmath import *
|
| 222 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 223 |
+
>>> sqrtm([[1,0], [0,1]])
|
| 224 |
+
[1.0 0.0]
|
| 225 |
+
[0.0 1.0]
|
| 226 |
+
>>> sqrtm([[0,0], [0,0]])
|
| 227 |
+
[0.0 0.0]
|
| 228 |
+
[0.0 0.0]
|
| 229 |
+
>>> sqrtm([[2,0],[0,1]])
|
| 230 |
+
[1.4142135623731 0.0]
|
| 231 |
+
[ 0.0 1.0]
|
| 232 |
+
>>> sqrtm([[1,1],[1,0]])
|
| 233 |
+
[ (0.920442065259926 - 0.21728689675164j) (0.568864481005783 + 0.351577584254143j)]
|
| 234 |
+
[(0.568864481005783 + 0.351577584254143j) (0.351577584254143 - 0.568864481005783j)]
|
| 235 |
+
>>> sqrtm([[1,0],[0,1]])
|
| 236 |
+
[1.0 0.0]
|
| 237 |
+
[0.0 1.0]
|
| 238 |
+
>>> sqrtm([[-1,0],[0,1]])
|
| 239 |
+
[(0.0 - 1.0j) 0.0]
|
| 240 |
+
[ 0.0 (1.0 + 0.0j)]
|
| 241 |
+
>>> sqrtm([[j,0],[0,j]])
|
| 242 |
+
[(0.707106781186547 + 0.707106781186547j) 0.0]
|
| 243 |
+
[ 0.0 (0.707106781186547 + 0.707106781186547j)]
|
| 244 |
+
|
| 245 |
+
A square root of a rotation matrix, giving the corresponding
|
| 246 |
+
half-angle rotation matrix::
|
| 247 |
+
|
| 248 |
+
>>> t1 = 0.75
|
| 249 |
+
>>> t2 = t1 * 0.5
|
| 250 |
+
>>> A1 = matrix([[cos(t1), -sin(t1)], [sin(t1), cos(t1)]])
|
| 251 |
+
>>> A2 = matrix([[cos(t2), -sin(t2)], [sin(t2), cos(t2)]])
|
| 252 |
+
>>> sqrtm(A1)
|
| 253 |
+
[0.930507621912314 -0.366272529086048]
|
| 254 |
+
[0.366272529086048 0.930507621912314]
|
| 255 |
+
>>> A2
|
| 256 |
+
[0.930507621912314 -0.366272529086048]
|
| 257 |
+
[0.366272529086048 0.930507621912314]
|
| 258 |
+
|
| 259 |
+
The identity `(A^2)^{1/2} = A` does not necessarily hold::
|
| 260 |
+
|
| 261 |
+
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
|
| 262 |
+
>>> sqrtm(A**2)
|
| 263 |
+
[ 4.0 1.0 4.0]
|
| 264 |
+
[ 7.0 8.0 9.0]
|
| 265 |
+
[10.0 2.0 11.0]
|
| 266 |
+
>>> sqrtm(A)**2
|
| 267 |
+
[ 4.0 1.0 4.0]
|
| 268 |
+
[ 7.0 8.0 9.0]
|
| 269 |
+
[10.0 2.0 11.0]
|
| 270 |
+
>>> A = matrix([[-4,1,4],[7,-8,9],[10,2,11]])
|
| 271 |
+
>>> sqrtm(A**2)
|
| 272 |
+
[ 7.43715112194995 -0.324127569985474 1.8481718827526]
|
| 273 |
+
[-0.251549715716942 9.32699765900402 2.48221180985147]
|
| 274 |
+
[ 4.11609388833616 0.775751877098258 13.017955697342]
|
| 275 |
+
>>> chop(sqrtm(A)**2)
|
| 276 |
+
[-4.0 1.0 4.0]
|
| 277 |
+
[ 7.0 -8.0 9.0]
|
| 278 |
+
[10.0 2.0 11.0]
|
| 279 |
+
|
| 280 |
+
For some matrices, a square root does not exist::
|
| 281 |
+
|
| 282 |
+
>>> sqrtm([[0,1], [0,0]])
|
| 283 |
+
Traceback (most recent call last):
|
| 284 |
+
...
|
| 285 |
+
ZeroDivisionError: matrix is numerically singular
|
| 286 |
+
|
| 287 |
+
Two examples from the documentation for Matlab's ``sqrtm``::
|
| 288 |
+
|
| 289 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 290 |
+
>>> sqrtm([[7,10],[15,22]])
|
| 291 |
+
[1.56669890360128 1.74077655955698]
|
| 292 |
+
[2.61116483933547 4.17786374293675]
|
| 293 |
+
>>>
|
| 294 |
+
>>> X = matrix(\
|
| 295 |
+
... [[5,-4,1,0,0],
|
| 296 |
+
... [-4,6,-4,1,0],
|
| 297 |
+
... [1,-4,6,-4,1],
|
| 298 |
+
... [0,1,-4,6,-4],
|
| 299 |
+
... [0,0,1,-4,5]])
|
| 300 |
+
>>> Y = matrix(\
|
| 301 |
+
... [[2,-1,-0,-0,-0],
|
| 302 |
+
... [-1,2,-1,0,-0],
|
| 303 |
+
... [0,-1,2,-1,0],
|
| 304 |
+
... [-0,0,-1,2,-1],
|
| 305 |
+
... [-0,-0,-0,-1,2]])
|
| 306 |
+
>>> mnorm(sqrtm(X) - Y)
|
| 307 |
+
4.53155328326114e-19
|
| 308 |
+
|
| 309 |
+
"""
|
| 310 |
+
A = ctx.matrix(A)
|
| 311 |
+
# Trivial
|
| 312 |
+
if A*0 == A:
|
| 313 |
+
return A
|
| 314 |
+
prec = ctx.prec
|
| 315 |
+
if _may_rotate:
|
| 316 |
+
d = ctx.det(A)
|
| 317 |
+
if abs(ctx.im(d)) < 16*ctx.eps and ctx.re(d) < 0:
|
| 318 |
+
return ctx._sqrtm_rot(A, _may_rotate-1)
|
| 319 |
+
try:
|
| 320 |
+
ctx.prec += 10
|
| 321 |
+
tol = ctx.eps * 128
|
| 322 |
+
Y = A
|
| 323 |
+
Z = I = A**0
|
| 324 |
+
k = 0
|
| 325 |
+
# Denman-Beavers iteration
|
| 326 |
+
while 1:
|
| 327 |
+
Yprev = Y
|
| 328 |
+
try:
|
| 329 |
+
Y, Z = 0.5*(Y+ctx.inverse(Z)), 0.5*(Z+ctx.inverse(Y))
|
| 330 |
+
except ZeroDivisionError:
|
| 331 |
+
if _may_rotate:
|
| 332 |
+
Y = ctx._sqrtm_rot(A, _may_rotate-1)
|
| 333 |
+
break
|
| 334 |
+
else:
|
| 335 |
+
raise
|
| 336 |
+
mag1 = ctx.mnorm(Y-Yprev, 'inf')
|
| 337 |
+
mag2 = ctx.mnorm(Y, 'inf')
|
| 338 |
+
if mag1 <= mag2*tol:
|
| 339 |
+
break
|
| 340 |
+
if _may_rotate and k > 6 and not mag1 < mag2 * 0.001:
|
| 341 |
+
return ctx._sqrtm_rot(A, _may_rotate-1)
|
| 342 |
+
k += 1
|
| 343 |
+
if k > ctx.prec:
|
| 344 |
+
raise ctx.NoConvergence
|
| 345 |
+
finally:
|
| 346 |
+
ctx.prec = prec
|
| 347 |
+
Y *= 1
|
| 348 |
+
return Y
|
| 349 |
+
|
| 350 |
+
def logm(ctx, A):
|
| 351 |
+
r"""
|
| 352 |
+
Computes a logarithm of the square matrix `A`, i.e. returns
|
| 353 |
+
a matrix `B = \log(A)` such that `\exp(B) = A`. The logarithm
|
| 354 |
+
of a matrix, if it exists, is not unique.
|
| 355 |
+
|
| 356 |
+
**Examples**
|
| 357 |
+
|
| 358 |
+
Logarithms of some simple matrices::
|
| 359 |
+
|
| 360 |
+
>>> from mpmath import *
|
| 361 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 362 |
+
>>> X = eye(3)
|
| 363 |
+
>>> logm(X)
|
| 364 |
+
[0.0 0.0 0.0]
|
| 365 |
+
[0.0 0.0 0.0]
|
| 366 |
+
[0.0 0.0 0.0]
|
| 367 |
+
>>> logm(2*X)
|
| 368 |
+
[0.693147180559945 0.0 0.0]
|
| 369 |
+
[ 0.0 0.693147180559945 0.0]
|
| 370 |
+
[ 0.0 0.0 0.693147180559945]
|
| 371 |
+
>>> logm(expm(X))
|
| 372 |
+
[1.0 0.0 0.0]
|
| 373 |
+
[0.0 1.0 0.0]
|
| 374 |
+
[0.0 0.0 1.0]
|
| 375 |
+
|
| 376 |
+
A logarithm of a complex matrix::
|
| 377 |
+
|
| 378 |
+
>>> X = matrix([[2+j, 1, 3], [1-j, 1-2*j, 1], [-4, -5, j]])
|
| 379 |
+
>>> B = logm(X)
|
| 380 |
+
>>> nprint(B)
|
| 381 |
+
[ (0.808757 + 0.107759j) (2.20752 + 0.202762j) (1.07376 - 0.773874j)]
|
| 382 |
+
[ (0.905709 - 0.107795j) (0.0287395 - 0.824993j) (0.111619 + 0.514272j)]
|
| 383 |
+
[(-0.930151 + 0.399512j) (-2.06266 - 0.674397j) (0.791552 + 0.519839j)]
|
| 384 |
+
>>> chop(expm(B))
|
| 385 |
+
[(2.0 + 1.0j) 1.0 3.0]
|
| 386 |
+
[(1.0 - 1.0j) (1.0 - 2.0j) 1.0]
|
| 387 |
+
[ -4.0 -5.0 (0.0 + 1.0j)]
|
| 388 |
+
|
| 389 |
+
A matrix `X` close to the identity matrix, for which
|
| 390 |
+
`\log(\exp(X)) = \exp(\log(X)) = X` holds::
|
| 391 |
+
|
| 392 |
+
>>> X = eye(3) + hilbert(3)/4
|
| 393 |
+
>>> X
|
| 394 |
+
[ 1.25 0.125 0.0833333333333333]
|
| 395 |
+
[ 0.125 1.08333333333333 0.0625]
|
| 396 |
+
[0.0833333333333333 0.0625 1.05]
|
| 397 |
+
>>> logm(expm(X))
|
| 398 |
+
[ 1.25 0.125 0.0833333333333333]
|
| 399 |
+
[ 0.125 1.08333333333333 0.0625]
|
| 400 |
+
[0.0833333333333333 0.0625 1.05]
|
| 401 |
+
>>> expm(logm(X))
|
| 402 |
+
[ 1.25 0.125 0.0833333333333333]
|
| 403 |
+
[ 0.125 1.08333333333333 0.0625]
|
| 404 |
+
[0.0833333333333333 0.0625 1.05]
|
| 405 |
+
|
| 406 |
+
A logarithm of a rotation matrix, giving back the angle of
|
| 407 |
+
the rotation::
|
| 408 |
+
|
| 409 |
+
>>> t = 3.7
|
| 410 |
+
>>> A = matrix([[cos(t),sin(t)],[-sin(t),cos(t)]])
|
| 411 |
+
>>> chop(logm(A))
|
| 412 |
+
[ 0.0 -2.58318530717959]
|
| 413 |
+
[2.58318530717959 0.0]
|
| 414 |
+
>>> (2*pi-t)
|
| 415 |
+
2.58318530717959
|
| 416 |
+
|
| 417 |
+
For some matrices, a logarithm does not exist::
|
| 418 |
+
|
| 419 |
+
>>> logm([[1,0], [0,0]])
|
| 420 |
+
Traceback (most recent call last):
|
| 421 |
+
...
|
| 422 |
+
ZeroDivisionError: matrix is numerically singular
|
| 423 |
+
|
| 424 |
+
Logarithm of a matrix with large entries::
|
| 425 |
+
|
| 426 |
+
>>> logm(hilbert(3) * 10**20).apply(re)
|
| 427 |
+
[ 45.5597513593433 1.27721006042799 0.317662687717978]
|
| 428 |
+
[ 1.27721006042799 42.5222778973542 2.24003708791604]
|
| 429 |
+
[0.317662687717978 2.24003708791604 42.395212822267]
|
| 430 |
+
|
| 431 |
+
"""
|
| 432 |
+
A = ctx.matrix(A)
|
| 433 |
+
prec = ctx.prec
|
| 434 |
+
try:
|
| 435 |
+
ctx.prec += 10
|
| 436 |
+
tol = ctx.eps * 128
|
| 437 |
+
I = A**0
|
| 438 |
+
B = A
|
| 439 |
+
n = 0
|
| 440 |
+
while 1:
|
| 441 |
+
B = ctx.sqrtm(B)
|
| 442 |
+
n += 1
|
| 443 |
+
if ctx.mnorm(B-I, 'inf') < 0.125:
|
| 444 |
+
break
|
| 445 |
+
T = X = B-I
|
| 446 |
+
L = X*0
|
| 447 |
+
k = 1
|
| 448 |
+
while 1:
|
| 449 |
+
if k & 1:
|
| 450 |
+
L += T / k
|
| 451 |
+
else:
|
| 452 |
+
L -= T / k
|
| 453 |
+
T *= X
|
| 454 |
+
if ctx.mnorm(T, 'inf') < tol:
|
| 455 |
+
break
|
| 456 |
+
k += 1
|
| 457 |
+
if k > ctx.prec:
|
| 458 |
+
raise ctx.NoConvergence
|
| 459 |
+
finally:
|
| 460 |
+
ctx.prec = prec
|
| 461 |
+
L *= 2**n
|
| 462 |
+
return L
|
| 463 |
+
|
| 464 |
+
def powm(ctx, A, r):
|
| 465 |
+
r"""
|
| 466 |
+
Computes `A^r = \exp(A \log r)` for a matrix `A` and complex
|
| 467 |
+
number `r`.
|
| 468 |
+
|
| 469 |
+
**Examples**
|
| 470 |
+
|
| 471 |
+
Powers and inverse powers of a matrix::
|
| 472 |
+
|
| 473 |
+
>>> from mpmath import *
|
| 474 |
+
>>> mp.dps = 15; mp.pretty = True
|
| 475 |
+
>>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
|
| 476 |
+
>>> powm(A, 2)
|
| 477 |
+
[ 63.0 20.0 69.0]
|
| 478 |
+
[174.0 89.0 199.0]
|
| 479 |
+
[164.0 48.0 179.0]
|
| 480 |
+
>>> chop(powm(powm(A, 4), 1/4.))
|
| 481 |
+
[ 4.0 1.0 4.0]
|
| 482 |
+
[ 7.0 8.0 9.0]
|
| 483 |
+
[10.0 2.0 11.0]
|
| 484 |
+
>>> powm(extraprec(20)(powm)(A, -4), -1/4.)
|
| 485 |
+
[ 4.0 1.0 4.0]
|
| 486 |
+
[ 7.0 8.0 9.0]
|
| 487 |
+
[10.0 2.0 11.0]
|
| 488 |
+
>>> chop(powm(powm(A, 1+0.5j), 1/(1+0.5j)))
|
| 489 |
+
[ 4.0 1.0 4.0]
|
| 490 |
+
[ 7.0 8.0 9.0]
|
| 491 |
+
[10.0 2.0 11.0]
|
| 492 |
+
>>> powm(extraprec(5)(powm)(A, -1.5), -1/(1.5))
|
| 493 |
+
[ 4.0 1.0 4.0]
|
| 494 |
+
[ 7.0 8.0 9.0]
|
| 495 |
+
[10.0 2.0 11.0]
|
| 496 |
+
|
| 497 |
+
A Fibonacci-generating matrix::
|
| 498 |
+
|
| 499 |
+
>>> powm([[1,1],[1,0]], 10)
|
| 500 |
+
[89.0 55.0]
|
| 501 |
+
[55.0 34.0]
|
| 502 |
+
>>> fib(10)
|
| 503 |
+
55.0
|
| 504 |
+
>>> powm([[1,1],[1,0]], 6.5)
|
| 505 |
+
[(16.5166626964253 - 0.0121089837381789j) (10.2078589271083 + 0.0195927472575932j)]
|
| 506 |
+
[(10.2078589271083 + 0.0195927472575932j) (6.30880376931698 - 0.0317017309957721j)]
|
| 507 |
+
>>> (phi**6.5 - (1-phi)**6.5)/sqrt(5)
|
| 508 |
+
(10.2078589271083 - 0.0195927472575932j)
|
| 509 |
+
>>> powm([[1,1],[1,0]], 6.2)
|
| 510 |
+
[ (14.3076953002666 - 0.008222855781077j) (8.81733464837593 + 0.0133048601383712j)]
|
| 511 |
+
[(8.81733464837593 + 0.0133048601383712j) (5.49036065189071 - 0.0215277159194482j)]
|
| 512 |
+
>>> (phi**6.2 - (1-phi)**6.2)/sqrt(5)
|
| 513 |
+
(8.81733464837593 - 0.0133048601383712j)
|
| 514 |
+
|
| 515 |
+
"""
|
| 516 |
+
A = ctx.matrix(A)
|
| 517 |
+
r = ctx.convert(r)
|
| 518 |
+
prec = ctx.prec
|
| 519 |
+
try:
|
| 520 |
+
ctx.prec += 10
|
| 521 |
+
if ctx.isint(r):
|
| 522 |
+
v = A ** int(r)
|
| 523 |
+
elif ctx.isint(r*2):
|
| 524 |
+
y = int(r*2)
|
| 525 |
+
v = ctx.sqrtm(A) ** y
|
| 526 |
+
else:
|
| 527 |
+
v = ctx.expm(r*ctx.logm(A))
|
| 528 |
+
finally:
|
| 529 |
+
ctx.prec = prec
|
| 530 |
+
v *= 1
|
| 531 |
+
return v
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/eigen.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
##################################################################################################
|
| 5 |
+
# module for the eigenvalue problem
|
| 6 |
+
# Copyright 2013 Timo Hartmann (thartmann15 at gmail.com)
|
| 7 |
+
#
|
| 8 |
+
# todo:
|
| 9 |
+
# - implement balancing
|
| 10 |
+
# - agressive early deflation
|
| 11 |
+
#
|
| 12 |
+
##################################################################################################
|
| 13 |
+
|
| 14 |
+
"""
|
| 15 |
+
The eigenvalue problem
|
| 16 |
+
----------------------
|
| 17 |
+
|
| 18 |
+
This file contains routines for the eigenvalue problem.
|
| 19 |
+
|
| 20 |
+
high level routines:
|
| 21 |
+
|
| 22 |
+
hessenberg : reduction of a real or complex square matrix to upper Hessenberg form
|
| 23 |
+
schur : reduction of a real or complex square matrix to upper Schur form
|
| 24 |
+
eig : eigenvalues and eigenvectors of a real or complex square matrix
|
| 25 |
+
|
| 26 |
+
low level routines:
|
| 27 |
+
|
| 28 |
+
hessenberg_reduce_0 : reduction of a real or complex square matrix to upper Hessenberg form
|
| 29 |
+
hessenberg_reduce_1 : auxiliary routine to hessenberg_reduce_0
|
| 30 |
+
qr_step : a single implicitly shifted QR step for an upper Hessenberg matrix
|
| 31 |
+
hessenberg_qr : Schur decomposition of an upper Hessenberg matrix
|
| 32 |
+
eig_tr_r : right eigenvectors of an upper triangular matrix
|
| 33 |
+
eig_tr_l : left eigenvectors of an upper triangular matrix
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from ..libmp.backend import xrange
|
| 37 |
+
|
| 38 |
+
class Eigen(object):
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
def defun(f):
|
| 42 |
+
setattr(Eigen, f.__name__, f)
|
| 43 |
+
return f
|
| 44 |
+
|
| 45 |
+
def hessenberg_reduce_0(ctx, A, T):
|
| 46 |
+
"""
|
| 47 |
+
This routine computes the (upper) Hessenberg decomposition of a square matrix A.
|
| 48 |
+
Given A, an unitary matrix Q is calculated such that
|
| 49 |
+
|
| 50 |
+
Q' A Q = H and Q' Q = Q Q' = 1
|
| 51 |
+
|
| 52 |
+
where H is an upper Hessenberg matrix, meaning that it only contains zeros
|
| 53 |
+
below the first subdiagonal. Here ' denotes the hermitian transpose (i.e.
|
| 54 |
+
transposition and conjugation).
|
| 55 |
+
|
| 56 |
+
parameters:
|
| 57 |
+
A (input/output) On input, A contains the square matrix A of
|
| 58 |
+
dimension (n,n). On output, A contains a compressed representation
|
| 59 |
+
of Q and H.
|
| 60 |
+
T (output) An array of length n containing the first elements of
|
| 61 |
+
the Householder reflectors.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
# internally we work with householder reflections from the right.
|
| 65 |
+
# let u be a row vector (i.e. u[i]=A[i,:i]). then
|
| 66 |
+
# Q is build up by reflectors of the type (1-v'v) where v is a suitable
|
| 67 |
+
# modification of u. these reflectors are applyed to A from the right.
|
| 68 |
+
# because we work with reflectors from the right we have to start with
|
| 69 |
+
# the bottom row of A and work then upwards (this corresponds to
|
| 70 |
+
# some kind of RQ decomposition).
|
| 71 |
+
# the first part of the vectors v (i.e. A[i,:(i-1)]) are stored as row vectors
|
| 72 |
+
# in the lower left part of A (excluding the diagonal and subdiagonal).
|
| 73 |
+
# the last entry of v is stored in T.
|
| 74 |
+
# the upper right part of A (including diagonal and subdiagonal) becomes H.
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
n = A.rows
|
| 78 |
+
if n <= 2: return
|
| 79 |
+
|
| 80 |
+
for i in xrange(n-1, 1, -1):
|
| 81 |
+
|
| 82 |
+
# scale the vector
|
| 83 |
+
|
| 84 |
+
scale = 0
|
| 85 |
+
for k in xrange(0, i):
|
| 86 |
+
scale += abs(ctx.re(A[i,k])) + abs(ctx.im(A[i,k]))
|
| 87 |
+
|
| 88 |
+
scale_inv = 0
|
| 89 |
+
if scale != 0:
|
| 90 |
+
scale_inv = 1 / scale
|
| 91 |
+
|
| 92 |
+
if scale == 0 or ctx.isinf(scale_inv):
|
| 93 |
+
# sadly there are floating point numbers not equal to zero whose reciprocal is infinity
|
| 94 |
+
T[i] = 0
|
| 95 |
+
A[i,i-1] = 0
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
# calculate parameters for housholder transformation
|
| 99 |
+
|
| 100 |
+
H = 0
|
| 101 |
+
for k in xrange(0, i):
|
| 102 |
+
A[i,k] *= scale_inv
|
| 103 |
+
rr = ctx.re(A[i,k])
|
| 104 |
+
ii = ctx.im(A[i,k])
|
| 105 |
+
H += rr * rr + ii * ii
|
| 106 |
+
|
| 107 |
+
F = A[i,i-1]
|
| 108 |
+
f = abs(F)
|
| 109 |
+
G = ctx.sqrt(H)
|
| 110 |
+
A[i,i-1] = - G * scale
|
| 111 |
+
|
| 112 |
+
if f == 0:
|
| 113 |
+
T[i] = G
|
| 114 |
+
else:
|
| 115 |
+
ff = F / f
|
| 116 |
+
T[i] = F + G * ff
|
| 117 |
+
A[i,i-1] *= ff
|
| 118 |
+
|
| 119 |
+
H += G * f
|
| 120 |
+
H = 1 / ctx.sqrt(H)
|
| 121 |
+
|
| 122 |
+
T[i] *= H
|
| 123 |
+
for k in xrange(0, i - 1):
|
| 124 |
+
A[i,k] *= H
|
| 125 |
+
|
| 126 |
+
for j in xrange(0, i):
|
| 127 |
+
# apply housholder transformation (from right)
|
| 128 |
+
|
| 129 |
+
G = ctx.conj(T[i]) * A[j,i-1]
|
| 130 |
+
for k in xrange(0, i-1):
|
| 131 |
+
G += ctx.conj(A[i,k]) * A[j,k]
|
| 132 |
+
|
| 133 |
+
A[j,i-1] -= G * T[i]
|
| 134 |
+
for k in xrange(0, i-1):
|
| 135 |
+
A[j,k] -= G * A[i,k]
|
| 136 |
+
|
| 137 |
+
for j in xrange(0, n):
|
| 138 |
+
# apply housholder transformation (from left)
|
| 139 |
+
|
| 140 |
+
G = T[i] * A[i-1,j]
|
| 141 |
+
for k in xrange(0, i-1):
|
| 142 |
+
G += A[i,k] * A[k,j]
|
| 143 |
+
|
| 144 |
+
A[i-1,j] -= G * ctx.conj(T[i])
|
| 145 |
+
for k in xrange(0, i-1):
|
| 146 |
+
A[k,j] -= G * ctx.conj(A[i,k])
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def hessenberg_reduce_1(ctx, A, T):
|
| 151 |
+
"""
|
| 152 |
+
This routine forms the unitary matrix Q described in hessenberg_reduce_0.
|
| 153 |
+
|
| 154 |
+
parameters:
|
| 155 |
+
A (input/output) On input, A is the same matrix as delivered by
|
| 156 |
+
hessenberg_reduce_0. On output, A is set to Q.
|
| 157 |
+
|
| 158 |
+
T (input) On input, T is the same array as delivered by hessenberg_reduce_0.
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
n = A.rows
|
| 162 |
+
|
| 163 |
+
if n == 1:
|
| 164 |
+
A[0,0] = 1
|
| 165 |
+
return
|
| 166 |
+
|
| 167 |
+
A[0,0] = A[1,1] = 1
|
| 168 |
+
A[0,1] = A[1,0] = 0
|
| 169 |
+
|
| 170 |
+
for i in xrange(2, n):
|
| 171 |
+
if T[i] != 0:
|
| 172 |
+
|
| 173 |
+
for j in xrange(0, i):
|
| 174 |
+
G = T[i] * A[i-1,j]
|
| 175 |
+
for k in xrange(0, i-1):
|
| 176 |
+
G += A[i,k] * A[k,j]
|
| 177 |
+
|
| 178 |
+
A[i-1,j] -= G * ctx.conj(T[i])
|
| 179 |
+
for k in xrange(0, i-1):
|
| 180 |
+
A[k,j] -= G * ctx.conj(A[i,k])
|
| 181 |
+
|
| 182 |
+
A[i,i] = 1
|
| 183 |
+
for j in xrange(0, i):
|
| 184 |
+
A[j,i] = A[i,j] = 0
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@defun
|
| 189 |
+
def hessenberg(ctx, A, overwrite_a = False):
|
| 190 |
+
"""
|
| 191 |
+
This routine computes the Hessenberg decomposition of a square matrix A.
|
| 192 |
+
Given A, an unitary matrix Q is determined such that
|
| 193 |
+
|
| 194 |
+
Q' A Q = H and Q' Q = Q Q' = 1
|
| 195 |
+
|
| 196 |
+
where H is an upper right Hessenberg matrix. Here ' denotes the hermitian
|
| 197 |
+
transpose (i.e. transposition and conjugation).
|
| 198 |
+
|
| 199 |
+
input:
|
| 200 |
+
A : a real or complex square matrix
|
| 201 |
+
overwrite_a : if true, allows modification of A which may improve
|
| 202 |
+
performance. if false, A is not modified.
|
| 203 |
+
|
| 204 |
+
output:
|
| 205 |
+
Q : an unitary matrix
|
| 206 |
+
H : an upper right Hessenberg matrix
|
| 207 |
+
|
| 208 |
+
example:
|
| 209 |
+
>>> from mpmath import mp
|
| 210 |
+
>>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
|
| 211 |
+
>>> Q, H = mp.hessenberg(A)
|
| 212 |
+
>>> mp.nprint(H, 3) # doctest:+SKIP
|
| 213 |
+
[ 3.15 2.23 4.44]
|
| 214 |
+
[-0.769 4.85 3.05]
|
| 215 |
+
[ 0.0 3.61 7.0]
|
| 216 |
+
>>> print(mp.chop(A - Q * H * Q.transpose_conj()))
|
| 217 |
+
[0.0 0.0 0.0]
|
| 218 |
+
[0.0 0.0 0.0]
|
| 219 |
+
[0.0 0.0 0.0]
|
| 220 |
+
|
| 221 |
+
return value: (Q, H)
|
| 222 |
+
"""
|
| 223 |
+
|
| 224 |
+
n = A.rows
|
| 225 |
+
|
| 226 |
+
if n == 1:
|
| 227 |
+
return (ctx.matrix([[1]]), A)
|
| 228 |
+
|
| 229 |
+
if not overwrite_a:
|
| 230 |
+
A = A.copy()
|
| 231 |
+
|
| 232 |
+
T = ctx.matrix(n, 1)
|
| 233 |
+
|
| 234 |
+
hessenberg_reduce_0(ctx, A, T)
|
| 235 |
+
Q = A.copy()
|
| 236 |
+
hessenberg_reduce_1(ctx, Q, T)
|
| 237 |
+
|
| 238 |
+
for x in xrange(n):
|
| 239 |
+
for y in xrange(x+2, n):
|
| 240 |
+
A[y,x] = 0
|
| 241 |
+
|
| 242 |
+
return Q, A
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
###########################################################################
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def qr_step(ctx, n0, n1, A, Q, shift):
|
| 249 |
+
"""
|
| 250 |
+
This subroutine executes a single implicitly shifted QR step applied to an
|
| 251 |
+
upper Hessenberg matrix A. Given A and shift as input, first an QR
|
| 252 |
+
decomposition is calculated:
|
| 253 |
+
|
| 254 |
+
Q R = A - shift * 1 .
|
| 255 |
+
|
| 256 |
+
The output is then following matrix:
|
| 257 |
+
|
| 258 |
+
R Q + shift * 1
|
| 259 |
+
|
| 260 |
+
parameters:
|
| 261 |
+
n0, n1 (input) Two integers which specify the submatrix A[n0:n1,n0:n1]
|
| 262 |
+
on which this subroutine operators. The subdiagonal elements
|
| 263 |
+
to the left and below this submatrix must be deflated (i.e. zero).
|
| 264 |
+
following restriction is imposed: n1>=n0+2
|
| 265 |
+
A (input/output) On input, A is an upper Hessenberg matrix.
|
| 266 |
+
On output, A is replaced by "R Q + shift * 1"
|
| 267 |
+
Q (input/output) The parameter Q is multiplied by the unitary matrix
|
| 268 |
+
Q arising from the QR decomposition. Q can also be false, in which
|
| 269 |
+
case the unitary matrix Q is not computated.
|
| 270 |
+
shift (input) a complex number specifying the shift. idealy close to an
|
| 271 |
+
eigenvalue of the bottemmost part of the submatrix A[n0:n1,n0:n1].
|
| 272 |
+
|
| 273 |
+
references:
|
| 274 |
+
Stoer, Bulirsch - Introduction to Numerical Analysis.
|
| 275 |
+
Kresser : Numerical Methods for General and Structured Eigenvalue Problems
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
# implicitly shifted and bulge chasing is explained at p.398/399 in "Stoer, Bulirsch - Introduction to Numerical Analysis"
|
| 279 |
+
# for bulge chasing see also "Watkins - The Matrix Eigenvalue Problem" sec.4.5,p.173
|
| 280 |
+
|
| 281 |
+
# the Givens rotation we used is determined as follows: let c,s be two complex
|
| 282 |
+
# numbers. then we have following relation:
|
| 283 |
+
#
|
| 284 |
+
# v = sqrt(|c|^2 + |s|^2)
|
| 285 |
+
#
|
| 286 |
+
# 1/v [ c~ s~] [c] = [v]
|
| 287 |
+
# [-s c ] [s] [0]
|
| 288 |
+
#
|
| 289 |
+
# the matrix on the left is our Givens rotation.
|
| 290 |
+
|
| 291 |
+
n = A.rows
|
| 292 |
+
|
| 293 |
+
# first step
|
| 294 |
+
|
| 295 |
+
# calculate givens rotation
|
| 296 |
+
c = A[n0 ,n0] - shift
|
| 297 |
+
s = A[n0+1,n0]
|
| 298 |
+
|
| 299 |
+
v = ctx.hypot(ctx.hypot(ctx.re(c), ctx.im(c)), ctx.hypot(ctx.re(s), ctx.im(s)))
|
| 300 |
+
|
| 301 |
+
if v == 0:
|
| 302 |
+
v = 1
|
| 303 |
+
c = 1
|
| 304 |
+
s = 0
|
| 305 |
+
else:
|
| 306 |
+
c /= v
|
| 307 |
+
s /= v
|
| 308 |
+
|
| 309 |
+
cc = ctx.conj(c)
|
| 310 |
+
cs = ctx.conj(s)
|
| 311 |
+
|
| 312 |
+
for k in xrange(n0, n):
|
| 313 |
+
# apply givens rotation from the left
|
| 314 |
+
x = A[n0 ,k]
|
| 315 |
+
y = A[n0+1,k]
|
| 316 |
+
A[n0 ,k] = cc * x + cs * y
|
| 317 |
+
A[n0+1,k] = c * y - s * x
|
| 318 |
+
|
| 319 |
+
for k in xrange(min(n1, n0+3)):
|
| 320 |
+
# apply givens rotation from the right
|
| 321 |
+
x = A[k,n0 ]
|
| 322 |
+
y = A[k,n0+1]
|
| 323 |
+
A[k,n0 ] = c * x + s * y
|
| 324 |
+
A[k,n0+1] = cc * y - cs * x
|
| 325 |
+
|
| 326 |
+
if not isinstance(Q, bool):
|
| 327 |
+
for k in xrange(n):
|
| 328 |
+
# eigenvectors
|
| 329 |
+
x = Q[k,n0 ]
|
| 330 |
+
y = Q[k,n0+1]
|
| 331 |
+
Q[k,n0 ] = c * x + s * y
|
| 332 |
+
Q[k,n0+1] = cc * y - cs * x
|
| 333 |
+
|
| 334 |
+
# chase the bulge
|
| 335 |
+
|
| 336 |
+
for j in xrange(n0, n1 - 2):
|
| 337 |
+
# calculate givens rotation
|
| 338 |
+
|
| 339 |
+
c = A[j+1,j]
|
| 340 |
+
s = A[j+2,j]
|
| 341 |
+
|
| 342 |
+
v = ctx.hypot(ctx.hypot(ctx.re(c), ctx.im(c)), ctx.hypot(ctx.re(s), ctx.im(s)))
|
| 343 |
+
|
| 344 |
+
if v == 0:
|
| 345 |
+
A[j+1,j] = 0
|
| 346 |
+
v = 1
|
| 347 |
+
c = 1
|
| 348 |
+
s = 0
|
| 349 |
+
else:
|
| 350 |
+
A[j+1,j] = v
|
| 351 |
+
c /= v
|
| 352 |
+
s /= v
|
| 353 |
+
|
| 354 |
+
A[j+2,j] = 0
|
| 355 |
+
|
| 356 |
+
cc = ctx.conj(c)
|
| 357 |
+
cs = ctx.conj(s)
|
| 358 |
+
|
| 359 |
+
for k in xrange(j+1, n):
|
| 360 |
+
# apply givens rotation from the left
|
| 361 |
+
x = A[j+1,k]
|
| 362 |
+
y = A[j+2,k]
|
| 363 |
+
A[j+1,k] = cc * x + cs * y
|
| 364 |
+
A[j+2,k] = c * y - s * x
|
| 365 |
+
|
| 366 |
+
for k in xrange(0, min(n1, j+4)):
|
| 367 |
+
# apply givens rotation from the right
|
| 368 |
+
x = A[k,j+1]
|
| 369 |
+
y = A[k,j+2]
|
| 370 |
+
A[k,j+1] = c * x + s * y
|
| 371 |
+
A[k,j+2] = cc * y - cs * x
|
| 372 |
+
|
| 373 |
+
if not isinstance(Q, bool):
|
| 374 |
+
for k in xrange(0, n):
|
| 375 |
+
# eigenvectors
|
| 376 |
+
x = Q[k,j+1]
|
| 377 |
+
y = Q[k,j+2]
|
| 378 |
+
Q[k,j+1] = c * x + s * y
|
| 379 |
+
Q[k,j+2] = cc * y - cs * x
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def hessenberg_qr(ctx, A, Q):
|
| 384 |
+
"""
|
| 385 |
+
This routine computes the Schur decomposition of an upper Hessenberg matrix A.
|
| 386 |
+
Given A, an unitary matrix Q is determined such that
|
| 387 |
+
|
| 388 |
+
Q' A Q = R and Q' Q = Q Q' = 1
|
| 389 |
+
|
| 390 |
+
where R is an upper right triangular matrix. Here ' denotes the hermitian
|
| 391 |
+
transpose (i.e. transposition and conjugation).
|
| 392 |
+
|
| 393 |
+
parameters:
|
| 394 |
+
A (input/output) On input, A contains an upper Hessenberg matrix.
|
| 395 |
+
On output, A is replace by the upper right triangluar matrix R.
|
| 396 |
+
|
| 397 |
+
Q (input/output) The parameter Q is multiplied by the unitary
|
| 398 |
+
matrix Q arising from the Schur decomposition. Q can also be
|
| 399 |
+
false, in which case the unitary matrix Q is not computated.
|
| 400 |
+
"""
|
| 401 |
+
|
| 402 |
+
n = A.rows
|
| 403 |
+
|
| 404 |
+
norm = 0
|
| 405 |
+
for x in xrange(n):
|
| 406 |
+
for y in xrange(min(x+2, n)):
|
| 407 |
+
norm += ctx.re(A[y,x]) ** 2 + ctx.im(A[y,x]) ** 2
|
| 408 |
+
norm = ctx.sqrt(norm) / n
|
| 409 |
+
|
| 410 |
+
if norm == 0:
|
| 411 |
+
return
|
| 412 |
+
|
| 413 |
+
n0 = 0
|
| 414 |
+
n1 = n
|
| 415 |
+
|
| 416 |
+
eps = ctx.eps / (100 * n)
|
| 417 |
+
maxits = ctx.dps * 4
|
| 418 |
+
|
| 419 |
+
its = totalits = 0
|
| 420 |
+
|
| 421 |
+
while 1:
|
| 422 |
+
# kressner p.32 algo 3
|
| 423 |
+
# the active submatrix is A[n0:n1,n0:n1]
|
| 424 |
+
|
| 425 |
+
k = n0
|
| 426 |
+
|
| 427 |
+
while k + 1 < n1:
|
| 428 |
+
s = abs(ctx.re(A[k,k])) + abs(ctx.im(A[k,k])) + abs(ctx.re(A[k+1,k+1])) + abs(ctx.im(A[k+1,k+1]))
|
| 429 |
+
if s < eps * norm:
|
| 430 |
+
s = norm
|
| 431 |
+
if abs(A[k+1,k]) < eps * s:
|
| 432 |
+
break
|
| 433 |
+
k += 1
|
| 434 |
+
|
| 435 |
+
if k + 1 < n1:
|
| 436 |
+
# deflation found at position (k+1, k)
|
| 437 |
+
|
| 438 |
+
A[k+1,k] = 0
|
| 439 |
+
n0 = k + 1
|
| 440 |
+
|
| 441 |
+
its = 0
|
| 442 |
+
|
| 443 |
+
if n0 + 1 >= n1:
|
| 444 |
+
# block of size at most two has converged
|
| 445 |
+
n0 = 0
|
| 446 |
+
n1 = k + 1
|
| 447 |
+
if n1 < 2:
|
| 448 |
+
# QR algorithm has converged
|
| 449 |
+
return
|
| 450 |
+
else:
|
| 451 |
+
if (its % 30) == 10:
|
| 452 |
+
# exceptional shift
|
| 453 |
+
shift = A[n1-1,n1-2]
|
| 454 |
+
elif (its % 30) == 20:
|
| 455 |
+
# exceptional shift
|
| 456 |
+
shift = abs(A[n1-1,n1-2])
|
| 457 |
+
elif (its % 30) == 29:
|
| 458 |
+
# exceptional shift
|
| 459 |
+
shift = norm
|
| 460 |
+
else:
|
| 461 |
+
# A = [ a b ] det(x-A)=x*x-x*tr(A)+det(A)
|
| 462 |
+
# [ c d ]
|
| 463 |
+
#
|
| 464 |
+
# eigenvalues bad: (tr(A)+sqrt((tr(A))**2-4*det(A)))/2
|
| 465 |
+
# bad because of cancellation if |c| is small and |a-d| is small, too.
|
| 466 |
+
#
|
| 467 |
+
# eigenvalues good: (a+d+sqrt((a-d)**2+4*b*c))/2
|
| 468 |
+
|
| 469 |
+
t = A[n1-2,n1-2] + A[n1-1,n1-1]
|
| 470 |
+
s = (A[n1-1,n1-1] - A[n1-2,n1-2]) ** 2 + 4 * A[n1-1,n1-2] * A[n1-2,n1-1]
|
| 471 |
+
if ctx.re(s) > 0:
|
| 472 |
+
s = ctx.sqrt(s)
|
| 473 |
+
else:
|
| 474 |
+
s = ctx.sqrt(-s) * 1j
|
| 475 |
+
a = (t + s) / 2
|
| 476 |
+
b = (t - s) / 2
|
| 477 |
+
if abs(A[n1-1,n1-1] - a) > abs(A[n1-1,n1-1] - b):
|
| 478 |
+
shift = b
|
| 479 |
+
else:
|
| 480 |
+
shift = a
|
| 481 |
+
|
| 482 |
+
its += 1
|
| 483 |
+
totalits += 1
|
| 484 |
+
|
| 485 |
+
qr_step(ctx, n0, n1, A, Q, shift)
|
| 486 |
+
|
| 487 |
+
if its > maxits:
|
| 488 |
+
raise RuntimeError("qr: failed to converge after %d steps" % its)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
@defun
|
| 492 |
+
def schur(ctx, A, overwrite_a = False):
|
| 493 |
+
"""
|
| 494 |
+
This routine computes the Schur decomposition of a square matrix A.
|
| 495 |
+
Given A, an unitary matrix Q is determined such that
|
| 496 |
+
|
| 497 |
+
Q' A Q = R and Q' Q = Q Q' = 1
|
| 498 |
+
|
| 499 |
+
where R is an upper right triangular matrix. Here ' denotes the
|
| 500 |
+
hermitian transpose (i.e. transposition and conjugation).
|
| 501 |
+
|
| 502 |
+
input:
|
| 503 |
+
A : a real or complex square matrix
|
| 504 |
+
overwrite_a : if true, allows modification of A which may improve
|
| 505 |
+
performance. if false, A is not modified.
|
| 506 |
+
|
| 507 |
+
output:
|
| 508 |
+
Q : an unitary matrix
|
| 509 |
+
R : an upper right triangular matrix
|
| 510 |
+
|
| 511 |
+
return value: (Q, R)
|
| 512 |
+
|
| 513 |
+
example:
|
| 514 |
+
>>> from mpmath import mp
|
| 515 |
+
>>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
|
| 516 |
+
>>> Q, R = mp.schur(A)
|
| 517 |
+
>>> mp.nprint(R, 3) # doctest:+SKIP
|
| 518 |
+
[2.0 0.417 -2.53]
|
| 519 |
+
[0.0 4.0 -4.74]
|
| 520 |
+
[0.0 0.0 9.0]
|
| 521 |
+
>>> print(mp.chop(A - Q * R * Q.transpose_conj()))
|
| 522 |
+
[0.0 0.0 0.0]
|
| 523 |
+
[0.0 0.0 0.0]
|
| 524 |
+
[0.0 0.0 0.0]
|
| 525 |
+
|
| 526 |
+
warning: The Schur decomposition is not unique.
|
| 527 |
+
"""
|
| 528 |
+
|
| 529 |
+
n = A.rows
|
| 530 |
+
|
| 531 |
+
if n == 1:
|
| 532 |
+
return (ctx.matrix([[1]]), A)
|
| 533 |
+
|
| 534 |
+
if not overwrite_a:
|
| 535 |
+
A = A.copy()
|
| 536 |
+
|
| 537 |
+
T = ctx.matrix(n, 1)
|
| 538 |
+
|
| 539 |
+
hessenberg_reduce_0(ctx, A, T)
|
| 540 |
+
Q = A.copy()
|
| 541 |
+
hessenberg_reduce_1(ctx, Q, T)
|
| 542 |
+
|
| 543 |
+
for x in xrange(n):
|
| 544 |
+
for y in xrange(x + 2, n):
|
| 545 |
+
A[y,x] = 0
|
| 546 |
+
|
| 547 |
+
hessenberg_qr(ctx, A, Q)
|
| 548 |
+
|
| 549 |
+
return Q, A
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def eig_tr_r(ctx, A):
|
| 553 |
+
"""
|
| 554 |
+
This routine calculates the right eigenvectors of an upper right triangular matrix.
|
| 555 |
+
|
| 556 |
+
input:
|
| 557 |
+
A an upper right triangular matrix
|
| 558 |
+
|
| 559 |
+
output:
|
| 560 |
+
ER a matrix whose columns form the right eigenvectors of A
|
| 561 |
+
|
| 562 |
+
return value: ER
|
| 563 |
+
"""
|
| 564 |
+
|
| 565 |
+
# this subroutine is inspired by the lapack routines ctrevc.f,clatrs.f
|
| 566 |
+
|
| 567 |
+
n = A.rows
|
| 568 |
+
|
| 569 |
+
ER = ctx.eye(n)
|
| 570 |
+
|
| 571 |
+
eps = ctx.eps
|
| 572 |
+
|
| 573 |
+
unfl = ctx.ldexp(ctx.one, -ctx.prec * 30)
|
| 574 |
+
# since mpmath effectively has no limits on the exponent, we simply scale doubles up
|
| 575 |
+
# original double has prec*20
|
| 576 |
+
|
| 577 |
+
smlnum = unfl * (n / eps)
|
| 578 |
+
simin = 1 / ctx.sqrt(eps)
|
| 579 |
+
|
| 580 |
+
rmax = 1
|
| 581 |
+
|
| 582 |
+
for i in xrange(1, n):
|
| 583 |
+
s = A[i,i]
|
| 584 |
+
|
| 585 |
+
smin = max(eps * abs(s), smlnum)
|
| 586 |
+
|
| 587 |
+
for j in xrange(i - 1, -1, -1):
|
| 588 |
+
|
| 589 |
+
r = 0
|
| 590 |
+
for k in xrange(j + 1, i + 1):
|
| 591 |
+
r += A[j,k] * ER[k,i]
|
| 592 |
+
|
| 593 |
+
t = A[j,j] - s
|
| 594 |
+
if abs(t) < smin:
|
| 595 |
+
t = smin
|
| 596 |
+
|
| 597 |
+
r = -r / t
|
| 598 |
+
ER[j,i] = r
|
| 599 |
+
|
| 600 |
+
rmax = max(rmax, abs(r))
|
| 601 |
+
if rmax > simin:
|
| 602 |
+
for k in xrange(j, i+1):
|
| 603 |
+
ER[k,i] /= rmax
|
| 604 |
+
rmax = 1
|
| 605 |
+
|
| 606 |
+
if rmax != 1:
|
| 607 |
+
for k in xrange(0, i + 1):
|
| 608 |
+
ER[k,i] /= rmax
|
| 609 |
+
|
| 610 |
+
return ER
|
| 611 |
+
|
| 612 |
+
def eig_tr_l(ctx, A):
|
| 613 |
+
"""
|
| 614 |
+
This routine calculates the left eigenvectors of an upper right triangular matrix.
|
| 615 |
+
|
| 616 |
+
input:
|
| 617 |
+
A an upper right triangular matrix
|
| 618 |
+
|
| 619 |
+
output:
|
| 620 |
+
EL a matrix whose rows form the left eigenvectors of A
|
| 621 |
+
|
| 622 |
+
return value: EL
|
| 623 |
+
"""
|
| 624 |
+
|
| 625 |
+
n = A.rows
|
| 626 |
+
|
| 627 |
+
EL = ctx.eye(n)
|
| 628 |
+
|
| 629 |
+
eps = ctx.eps
|
| 630 |
+
|
| 631 |
+
unfl = ctx.ldexp(ctx.one, -ctx.prec * 30)
|
| 632 |
+
# since mpmath effectively has no limits on the exponent, we simply scale doubles up
|
| 633 |
+
# original double has prec*20
|
| 634 |
+
|
| 635 |
+
smlnum = unfl * (n / eps)
|
| 636 |
+
simin = 1 / ctx.sqrt(eps)
|
| 637 |
+
|
| 638 |
+
rmax = 1
|
| 639 |
+
|
| 640 |
+
for i in xrange(0, n - 1):
|
| 641 |
+
s = A[i,i]
|
| 642 |
+
|
| 643 |
+
smin = max(eps * abs(s), smlnum)
|
| 644 |
+
|
| 645 |
+
for j in xrange(i + 1, n):
|
| 646 |
+
|
| 647 |
+
r = 0
|
| 648 |
+
for k in xrange(i, j):
|
| 649 |
+
r += EL[i,k] * A[k,j]
|
| 650 |
+
|
| 651 |
+
t = A[j,j] - s
|
| 652 |
+
if abs(t) < smin:
|
| 653 |
+
t = smin
|
| 654 |
+
|
| 655 |
+
r = -r / t
|
| 656 |
+
EL[i,j] = r
|
| 657 |
+
|
| 658 |
+
rmax = max(rmax, abs(r))
|
| 659 |
+
if rmax > simin:
|
| 660 |
+
for k in xrange(i, j + 1):
|
| 661 |
+
EL[i,k] /= rmax
|
| 662 |
+
rmax = 1
|
| 663 |
+
|
| 664 |
+
if rmax != 1:
|
| 665 |
+
for k in xrange(i, n):
|
| 666 |
+
EL[i,k] /= rmax
|
| 667 |
+
|
| 668 |
+
return EL
|
| 669 |
+
|
| 670 |
+
@defun
|
| 671 |
+
def eig(ctx, A, left = False, right = True, overwrite_a = False):
|
| 672 |
+
"""
|
| 673 |
+
This routine computes the eigenvalues and optionally the left and right
|
| 674 |
+
eigenvectors of a square matrix A. Given A, a vector E and matrices ER
|
| 675 |
+
and EL are calculated such that
|
| 676 |
+
|
| 677 |
+
A ER[:,i] = E[i] ER[:,i]
|
| 678 |
+
EL[i,:] A = EL[i,:] E[i]
|
| 679 |
+
|
| 680 |
+
E contains the eigenvalues of A. The columns of ER contain the right eigenvectors
|
| 681 |
+
of A whereas the rows of EL contain the left eigenvectors.
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
input:
|
| 685 |
+
A : a real or complex square matrix of shape (n, n)
|
| 686 |
+
left : if true, the left eigenvectors are calculated.
|
| 687 |
+
right : if true, the right eigenvectors are calculated.
|
| 688 |
+
overwrite_a : if true, allows modification of A which may improve
|
| 689 |
+
performance. if false, A is not modified.
|
| 690 |
+
|
| 691 |
+
output:
|
| 692 |
+
E : a list of length n containing the eigenvalues of A.
|
| 693 |
+
ER : a matrix whose columns contain the right eigenvectors of A.
|
| 694 |
+
EL : a matrix whose rows contain the left eigenvectors of A.
|
| 695 |
+
|
| 696 |
+
return values:
|
| 697 |
+
E if left and right are both false.
|
| 698 |
+
(E, ER) if right is true and left is false.
|
| 699 |
+
(E, EL) if left is true and right is false.
|
| 700 |
+
(E, EL, ER) if left and right are true.
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
examples:
|
| 704 |
+
>>> from mpmath import mp
|
| 705 |
+
>>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
|
| 706 |
+
>>> E, ER = mp.eig(A)
|
| 707 |
+
>>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
|
| 708 |
+
[0.0]
|
| 709 |
+
[0.0]
|
| 710 |
+
[0.0]
|
| 711 |
+
|
| 712 |
+
>>> E, EL, ER = mp.eig(A,left = True, right = True)
|
| 713 |
+
>>> E, EL, ER = mp.eig_sort(E, EL, ER)
|
| 714 |
+
>>> mp.nprint(E)
|
| 715 |
+
[2.0, 4.0, 9.0]
|
| 716 |
+
>>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
|
| 717 |
+
[0.0]
|
| 718 |
+
[0.0]
|
| 719 |
+
[0.0]
|
| 720 |
+
>>> print(mp.chop( EL[0,:] * A - EL[0,:] * E[0]))
|
| 721 |
+
[0.0 0.0 0.0]
|
| 722 |
+
|
| 723 |
+
warning:
|
| 724 |
+
- If there are multiple eigenvalues, the eigenvectors do not necessarily
|
| 725 |
+
span the whole vectorspace, i.e. ER and EL may have not full rank.
|
| 726 |
+
Furthermore in that case the eigenvectors are numerical ill-conditioned.
|
| 727 |
+
- In the general case the eigenvalues have no natural order.
|
| 728 |
+
|
| 729 |
+
see also:
|
| 730 |
+
- eigh (or eigsy, eighe) for the symmetric eigenvalue problem.
|
| 731 |
+
- eig_sort for sorting of eigenvalues and eigenvectors
|
| 732 |
+
"""
|
| 733 |
+
|
| 734 |
+
n = A.rows
|
| 735 |
+
|
| 736 |
+
if n == 1:
|
| 737 |
+
if left and (not right):
|
| 738 |
+
return ([A[0]], ctx.matrix([[1]]))
|
| 739 |
+
|
| 740 |
+
if right and (not left):
|
| 741 |
+
return ([A[0]], ctx.matrix([[1]]))
|
| 742 |
+
|
| 743 |
+
return ([A[0]], ctx.matrix([[1]]), ctx.matrix([[1]]))
|
| 744 |
+
|
| 745 |
+
if not overwrite_a:
|
| 746 |
+
A = A.copy()
|
| 747 |
+
|
| 748 |
+
T = ctx.zeros(n, 1)
|
| 749 |
+
|
| 750 |
+
hessenberg_reduce_0(ctx, A, T)
|
| 751 |
+
|
| 752 |
+
if left or right:
|
| 753 |
+
Q = A.copy()
|
| 754 |
+
hessenberg_reduce_1(ctx, Q, T)
|
| 755 |
+
else:
|
| 756 |
+
Q = False
|
| 757 |
+
|
| 758 |
+
for x in xrange(n):
|
| 759 |
+
for y in xrange(x + 2, n):
|
| 760 |
+
A[y,x] = 0
|
| 761 |
+
|
| 762 |
+
hessenberg_qr(ctx, A, Q)
|
| 763 |
+
|
| 764 |
+
E = [0 for i in xrange(n)]
|
| 765 |
+
for i in xrange(n):
|
| 766 |
+
E[i] = A[i,i]
|
| 767 |
+
|
| 768 |
+
if not (left or right):
|
| 769 |
+
return E
|
| 770 |
+
|
| 771 |
+
if left:
|
| 772 |
+
EL = eig_tr_l(ctx, A)
|
| 773 |
+
EL = EL * Q.transpose_conj()
|
| 774 |
+
|
| 775 |
+
if right:
|
| 776 |
+
ER = eig_tr_r(ctx, A)
|
| 777 |
+
ER = Q * ER
|
| 778 |
+
|
| 779 |
+
if left and (not right):
|
| 780 |
+
return (E, EL)
|
| 781 |
+
|
| 782 |
+
if right and (not left):
|
| 783 |
+
return (E, ER)
|
| 784 |
+
|
| 785 |
+
return (E, EL, ER)
|
| 786 |
+
|
| 787 |
+
@defun
|
| 788 |
+
def eig_sort(ctx, E, EL = False, ER = False, f = "real"):
|
| 789 |
+
"""
|
| 790 |
+
This routine sorts the eigenvalues and eigenvectors delivered by ``eig``.
|
| 791 |
+
|
| 792 |
+
parameters:
|
| 793 |
+
E : the eigenvalues as delivered by eig
|
| 794 |
+
EL : the left eigenvectors as delivered by eig, or false
|
| 795 |
+
ER : the right eigenvectors as delivered by eig, or false
|
| 796 |
+
f : either a string ("real" sort by increasing real part, "imag" sort by
|
| 797 |
+
increasing imag part, "abs" sort by absolute value) or a function
|
| 798 |
+
mapping complexs to the reals, i.e. ``f = lambda x: -mp.re(x) ``
|
| 799 |
+
would sort the eigenvalues by decreasing real part.
|
| 800 |
+
|
| 801 |
+
return values:
|
| 802 |
+
E if EL and ER are both false.
|
| 803 |
+
(E, ER) if ER is not false and left is false.
|
| 804 |
+
(E, EL) if EL is not false and right is false.
|
| 805 |
+
(E, EL, ER) if EL and ER are not false.
|
| 806 |
+
|
| 807 |
+
example:
|
| 808 |
+
>>> from mpmath import mp
|
| 809 |
+
>>> A = mp.matrix([[3, -1, 2], [2, 5, -5], [-2, -3, 7]])
|
| 810 |
+
>>> E, EL, ER = mp.eig(A,left = True, right = True)
|
| 811 |
+
>>> E, EL, ER = mp.eig_sort(E, EL, ER)
|
| 812 |
+
>>> mp.nprint(E)
|
| 813 |
+
[2.0, 4.0, 9.0]
|
| 814 |
+
>>> E, EL, ER = mp.eig_sort(E, EL, ER,f = lambda x: -mp.re(x))
|
| 815 |
+
>>> mp.nprint(E)
|
| 816 |
+
[9.0, 4.0, 2.0]
|
| 817 |
+
>>> print(mp.chop(A * ER[:,0] - E[0] * ER[:,0]))
|
| 818 |
+
[0.0]
|
| 819 |
+
[0.0]
|
| 820 |
+
[0.0]
|
| 821 |
+
>>> print(mp.chop( EL[0,:] * A - EL[0,:] * E[0]))
|
| 822 |
+
[0.0 0.0 0.0]
|
| 823 |
+
"""
|
| 824 |
+
|
| 825 |
+
if isinstance(f, str):
|
| 826 |
+
if f == "real":
|
| 827 |
+
f = ctx.re
|
| 828 |
+
elif f == "imag":
|
| 829 |
+
f = ctx.im
|
| 830 |
+
elif f == "abs":
|
| 831 |
+
f = abs
|
| 832 |
+
else:
|
| 833 |
+
raise RuntimeError("unknown function %s" % f)
|
| 834 |
+
|
| 835 |
+
n = len(E)
|
| 836 |
+
|
| 837 |
+
# Sort eigenvalues (bubble-sort)
|
| 838 |
+
|
| 839 |
+
for i in xrange(n):
|
| 840 |
+
imax = i
|
| 841 |
+
s = f(E[i]) # s is the current maximal element
|
| 842 |
+
|
| 843 |
+
for j in xrange(i + 1, n):
|
| 844 |
+
c = f(E[j])
|
| 845 |
+
if c < s:
|
| 846 |
+
s = c
|
| 847 |
+
imax = j
|
| 848 |
+
|
| 849 |
+
if imax != i:
|
| 850 |
+
# swap eigenvalues
|
| 851 |
+
|
| 852 |
+
z = E[i]
|
| 853 |
+
E[i] = E[imax]
|
| 854 |
+
E[imax] = z
|
| 855 |
+
|
| 856 |
+
if not isinstance(EL, bool):
|
| 857 |
+
for j in xrange(n):
|
| 858 |
+
z = EL[i,j]
|
| 859 |
+
EL[i,j] = EL[imax,j]
|
| 860 |
+
EL[imax,j] = z
|
| 861 |
+
|
| 862 |
+
if not isinstance(ER, bool):
|
| 863 |
+
for j in xrange(n):
|
| 864 |
+
z = ER[j,i]
|
| 865 |
+
ER[j,i] = ER[j,imax]
|
| 866 |
+
ER[j,imax] = z
|
| 867 |
+
|
| 868 |
+
if isinstance(EL, bool) and isinstance(ER, bool):
|
| 869 |
+
return E
|
| 870 |
+
|
| 871 |
+
if isinstance(EL, bool) and not(isinstance(ER, bool)):
|
| 872 |
+
return (E, ER)
|
| 873 |
+
|
| 874 |
+
if isinstance(ER, bool) and not(isinstance(EL, bool)):
|
| 875 |
+
return (E, EL)
|
| 876 |
+
|
| 877 |
+
return (E, EL, ER)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/linalg.py
ADDED
|
@@ -0,0 +1,790 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Linear algebra
|
| 3 |
+
--------------
|
| 4 |
+
|
| 5 |
+
Linear equations
|
| 6 |
+
................
|
| 7 |
+
|
| 8 |
+
Basic linear algebra is implemented; you can for example solve the linear
|
| 9 |
+
equation system::
|
| 10 |
+
|
| 11 |
+
x + 2*y = -10
|
| 12 |
+
3*x + 4*y = 10
|
| 13 |
+
|
| 14 |
+
using ``lu_solve``::
|
| 15 |
+
|
| 16 |
+
>>> from mpmath import *
|
| 17 |
+
>>> mp.pretty = False
|
| 18 |
+
>>> A = matrix([[1, 2], [3, 4]])
|
| 19 |
+
>>> b = matrix([-10, 10])
|
| 20 |
+
>>> x = lu_solve(A, b)
|
| 21 |
+
>>> x
|
| 22 |
+
matrix(
|
| 23 |
+
[['30.0'],
|
| 24 |
+
['-20.0']])
|
| 25 |
+
|
| 26 |
+
If you don't trust the result, use ``residual`` to calculate the residual ||A*x-b||::
|
| 27 |
+
|
| 28 |
+
>>> residual(A, x, b)
|
| 29 |
+
matrix(
|
| 30 |
+
[['3.46944695195361e-18'],
|
| 31 |
+
['3.46944695195361e-18']])
|
| 32 |
+
>>> str(eps)
|
| 33 |
+
'2.22044604925031e-16'
|
| 34 |
+
|
| 35 |
+
As you can see, the solution is quite accurate. The error is caused by the
|
| 36 |
+
inaccuracy of the internal floating point arithmetic. Though, it's even smaller
|
| 37 |
+
than the current machine epsilon, which basically means you can trust the
|
| 38 |
+
result.
|
| 39 |
+
|
| 40 |
+
If you need more speed, use NumPy, or ``fp.lu_solve`` for a floating-point computation.
|
| 41 |
+
|
| 42 |
+
>>> fp.lu_solve(A, b) # doctest: +ELLIPSIS
|
| 43 |
+
matrix(...)
|
| 44 |
+
|
| 45 |
+
``lu_solve`` accepts overdetermined systems. It is usually not possible to solve
|
| 46 |
+
such systems, so the residual is minimized instead. Internally this is done
|
| 47 |
+
using Cholesky decomposition to compute a least squares approximation. This means
|
| 48 |
+
that that ``lu_solve`` will square the errors. If you can't afford this, use
|
| 49 |
+
``qr_solve`` instead. It is twice as slow but more accurate, and it calculates
|
| 50 |
+
the residual automatically.
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
Matrix factorization
|
| 54 |
+
....................
|
| 55 |
+
|
| 56 |
+
The function ``lu`` computes an explicit LU factorization of a matrix::
|
| 57 |
+
|
| 58 |
+
>>> P, L, U = lu(matrix([[0,2,3],[4,5,6],[7,8,9]]))
|
| 59 |
+
>>> print(P)
|
| 60 |
+
[0.0 0.0 1.0]
|
| 61 |
+
[1.0 0.0 0.0]
|
| 62 |
+
[0.0 1.0 0.0]
|
| 63 |
+
>>> print(L)
|
| 64 |
+
[ 1.0 0.0 0.0]
|
| 65 |
+
[ 0.0 1.0 0.0]
|
| 66 |
+
[0.571428571428571 0.214285714285714 1.0]
|
| 67 |
+
>>> print(U)
|
| 68 |
+
[7.0 8.0 9.0]
|
| 69 |
+
[0.0 2.0 3.0]
|
| 70 |
+
[0.0 0.0 0.214285714285714]
|
| 71 |
+
>>> print(P.T*L*U)
|
| 72 |
+
[0.0 2.0 3.0]
|
| 73 |
+
[4.0 5.0 6.0]
|
| 74 |
+
[7.0 8.0 9.0]
|
| 75 |
+
|
| 76 |
+
Interval matrices
|
| 77 |
+
-----------------
|
| 78 |
+
|
| 79 |
+
Matrices may contain interval elements. This allows one to perform
|
| 80 |
+
basic linear algebra operations such as matrix multiplication
|
| 81 |
+
and equation solving with rigorous error bounds::
|
| 82 |
+
|
| 83 |
+
>>> a = iv.matrix([['0.1','0.3','1.0'],
|
| 84 |
+
... ['7.1','5.5','4.8'],
|
| 85 |
+
... ['3.2','4.4','5.6']])
|
| 86 |
+
>>>
|
| 87 |
+
>>> b = iv.matrix(['4','0.6','0.5'])
|
| 88 |
+
>>> c = iv.lu_solve(a, b)
|
| 89 |
+
>>> print(c)
|
| 90 |
+
[ [5.2582327113062568605927528666, 5.25823271130625686059275702219]]
|
| 91 |
+
[[-13.1550493962678375411635581388, -13.1550493962678375411635540152]]
|
| 92 |
+
[ [7.42069154774972557628979076189, 7.42069154774972557628979190734]]
|
| 93 |
+
>>> print(a*c)
|
| 94 |
+
[ [3.99999999999999999999999844904, 4.00000000000000000000000155096]]
|
| 95 |
+
[[0.599999999999999999999968898009, 0.600000000000000000000031763736]]
|
| 96 |
+
[[0.499999999999999999999979320485, 0.500000000000000000000020679515]]
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
# TODO:
|
| 100 |
+
# *implement high-level qr()
|
| 101 |
+
# *test unitvector
|
| 102 |
+
# *iterative solving
|
| 103 |
+
|
| 104 |
+
from copy import copy
|
| 105 |
+
|
| 106 |
+
from ..libmp.backend import xrange
|
| 107 |
+
|
| 108 |
+
class LinearAlgebraMethods(object):
|
| 109 |
+
|
| 110 |
+
def LU_decomp(ctx, A, overwrite=False, use_cache=True):
|
| 111 |
+
"""
|
| 112 |
+
LU-factorization of a n*n matrix using the Gauss algorithm.
|
| 113 |
+
Returns L and U in one matrix and the pivot indices.
|
| 114 |
+
|
| 115 |
+
Use overwrite to specify whether A will be overwritten with L and U.
|
| 116 |
+
"""
|
| 117 |
+
if not A.rows == A.cols:
|
| 118 |
+
raise ValueError('need n*n matrix')
|
| 119 |
+
# get from cache if possible
|
| 120 |
+
if use_cache and isinstance(A, ctx.matrix) and A._LU:
|
| 121 |
+
return A._LU
|
| 122 |
+
if not overwrite:
|
| 123 |
+
orig = A
|
| 124 |
+
A = A.copy()
|
| 125 |
+
tol = ctx.absmin(ctx.mnorm(A,1) * ctx.eps) # each pivot element has to be bigger
|
| 126 |
+
n = A.rows
|
| 127 |
+
p = [None]*(n - 1)
|
| 128 |
+
for j in xrange(n - 1):
|
| 129 |
+
# pivoting, choose max(abs(reciprocal row sum)*abs(pivot element))
|
| 130 |
+
biggest = 0
|
| 131 |
+
for k in xrange(j, n):
|
| 132 |
+
s = ctx.fsum([ctx.absmin(A[k,l]) for l in xrange(j, n)])
|
| 133 |
+
if ctx.absmin(s) <= tol:
|
| 134 |
+
raise ZeroDivisionError('matrix is numerically singular')
|
| 135 |
+
current = 1/s * ctx.absmin(A[k,j])
|
| 136 |
+
if current > biggest: # TODO: what if equal?
|
| 137 |
+
biggest = current
|
| 138 |
+
p[j] = k
|
| 139 |
+
# swap rows according to p
|
| 140 |
+
ctx.swap_row(A, j, p[j])
|
| 141 |
+
if ctx.absmin(A[j,j]) <= tol:
|
| 142 |
+
raise ZeroDivisionError('matrix is numerically singular')
|
| 143 |
+
# calculate elimination factors and add rows
|
| 144 |
+
for i in xrange(j + 1, n):
|
| 145 |
+
A[i,j] /= A[j,j]
|
| 146 |
+
for k in xrange(j + 1, n):
|
| 147 |
+
A[i,k] -= A[i,j]*A[j,k]
|
| 148 |
+
if ctx.absmin(A[n - 1,n - 1]) <= tol:
|
| 149 |
+
raise ZeroDivisionError('matrix is numerically singular')
|
| 150 |
+
# cache decomposition
|
| 151 |
+
if not overwrite and isinstance(orig, ctx.matrix):
|
| 152 |
+
orig._LU = (A, p)
|
| 153 |
+
return A, p
|
| 154 |
+
|
| 155 |
+
def L_solve(ctx, L, b, p=None):
|
| 156 |
+
"""
|
| 157 |
+
Solve the lower part of a LU factorized matrix for y.
|
| 158 |
+
"""
|
| 159 |
+
if L.rows != L.cols:
|
| 160 |
+
raise RuntimeError("need n*n matrix")
|
| 161 |
+
n = L.rows
|
| 162 |
+
if len(b) != n:
|
| 163 |
+
raise ValueError("Value should be equal to n")
|
| 164 |
+
b = copy(b)
|
| 165 |
+
if p: # swap b according to p
|
| 166 |
+
for k in xrange(0, len(p)):
|
| 167 |
+
ctx.swap_row(b, k, p[k])
|
| 168 |
+
# solve
|
| 169 |
+
for i in xrange(1, n):
|
| 170 |
+
for j in xrange(i):
|
| 171 |
+
b[i] -= L[i,j] * b[j]
|
| 172 |
+
return b
|
| 173 |
+
|
| 174 |
+
def U_solve(ctx, U, y):
|
| 175 |
+
"""
|
| 176 |
+
Solve the upper part of a LU factorized matrix for x.
|
| 177 |
+
"""
|
| 178 |
+
if U.rows != U.cols:
|
| 179 |
+
raise RuntimeError("need n*n matrix")
|
| 180 |
+
n = U.rows
|
| 181 |
+
if len(y) != n:
|
| 182 |
+
raise ValueError("Value should be equal to n")
|
| 183 |
+
x = copy(y)
|
| 184 |
+
for i in xrange(n - 1, -1, -1):
|
| 185 |
+
for j in xrange(i + 1, n):
|
| 186 |
+
x[i] -= U[i,j] * x[j]
|
| 187 |
+
x[i] /= U[i,i]
|
| 188 |
+
return x
|
| 189 |
+
|
| 190 |
+
def lu_solve(ctx, A, b, **kwargs):
|
| 191 |
+
"""
|
| 192 |
+
Ax = b => x
|
| 193 |
+
|
| 194 |
+
Solve a determined or overdetermined linear equations system.
|
| 195 |
+
Fast LU decomposition is used, which is less accurate than QR decomposition
|
| 196 |
+
(especially for overdetermined systems), but it's twice as efficient.
|
| 197 |
+
Use qr_solve if you want more precision or have to solve a very ill-
|
| 198 |
+
conditioned system.
|
| 199 |
+
|
| 200 |
+
If you specify real=True, it does not check for overdeterminded complex
|
| 201 |
+
systems.
|
| 202 |
+
"""
|
| 203 |
+
prec = ctx.prec
|
| 204 |
+
try:
|
| 205 |
+
ctx.prec += 10
|
| 206 |
+
# do not overwrite A nor b
|
| 207 |
+
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
| 208 |
+
if A.rows < A.cols:
|
| 209 |
+
raise ValueError('cannot solve underdetermined system')
|
| 210 |
+
if A.rows > A.cols:
|
| 211 |
+
# use least-squares method if overdetermined
|
| 212 |
+
# (this increases errors)
|
| 213 |
+
AH = A.H
|
| 214 |
+
A = AH * A
|
| 215 |
+
b = AH * b
|
| 216 |
+
if (kwargs.get('real', False) or
|
| 217 |
+
not sum(type(i) is ctx.mpc for i in A)):
|
| 218 |
+
# TODO: necessary to check also b?
|
| 219 |
+
x = ctx.cholesky_solve(A, b)
|
| 220 |
+
else:
|
| 221 |
+
x = ctx.lu_solve(A, b)
|
| 222 |
+
else:
|
| 223 |
+
# LU factorization
|
| 224 |
+
A, p = ctx.LU_decomp(A)
|
| 225 |
+
b = ctx.L_solve(A, b, p)
|
| 226 |
+
x = ctx.U_solve(A, b)
|
| 227 |
+
finally:
|
| 228 |
+
ctx.prec = prec
|
| 229 |
+
return x
|
| 230 |
+
|
| 231 |
+
def improve_solution(ctx, A, x, b, maxsteps=1):
|
| 232 |
+
"""
|
| 233 |
+
Improve a solution to a linear equation system iteratively.
|
| 234 |
+
|
| 235 |
+
This re-uses the LU decomposition and is thus cheap.
|
| 236 |
+
Usually 3 up to 4 iterations are giving the maximal improvement.
|
| 237 |
+
"""
|
| 238 |
+
if A.rows != A.cols:
|
| 239 |
+
raise RuntimeError("need n*n matrix") # TODO: really?
|
| 240 |
+
for _ in xrange(maxsteps):
|
| 241 |
+
r = ctx.residual(A, x, b)
|
| 242 |
+
if ctx.norm(r, 2) < 10*ctx.eps:
|
| 243 |
+
break
|
| 244 |
+
# this uses cached LU decomposition and is thus cheap
|
| 245 |
+
dx = ctx.lu_solve(A, -r)
|
| 246 |
+
x += dx
|
| 247 |
+
return x
|
| 248 |
+
|
| 249 |
+
def lu(ctx, A):
|
| 250 |
+
"""
|
| 251 |
+
A -> P, L, U
|
| 252 |
+
|
| 253 |
+
LU factorisation of a square matrix A. L is the lower, U the upper part.
|
| 254 |
+
P is the permutation matrix indicating the row swaps.
|
| 255 |
+
|
| 256 |
+
P*A = L*U
|
| 257 |
+
|
| 258 |
+
If you need efficiency, use the low-level method LU_decomp instead, it's
|
| 259 |
+
much more memory efficient.
|
| 260 |
+
"""
|
| 261 |
+
# get factorization
|
| 262 |
+
A, p = ctx.LU_decomp(A)
|
| 263 |
+
n = A.rows
|
| 264 |
+
L = ctx.matrix(n)
|
| 265 |
+
U = ctx.matrix(n)
|
| 266 |
+
for i in xrange(n):
|
| 267 |
+
for j in xrange(n):
|
| 268 |
+
if i > j:
|
| 269 |
+
L[i,j] = A[i,j]
|
| 270 |
+
elif i == j:
|
| 271 |
+
L[i,j] = 1
|
| 272 |
+
U[i,j] = A[i,j]
|
| 273 |
+
else:
|
| 274 |
+
U[i,j] = A[i,j]
|
| 275 |
+
# calculate permutation matrix
|
| 276 |
+
P = ctx.eye(n)
|
| 277 |
+
for k in xrange(len(p)):
|
| 278 |
+
ctx.swap_row(P, k, p[k])
|
| 279 |
+
return P, L, U
|
| 280 |
+
|
| 281 |
+
def unitvector(ctx, n, i):
|
| 282 |
+
"""
|
| 283 |
+
Return the i-th n-dimensional unit vector.
|
| 284 |
+
"""
|
| 285 |
+
assert 0 < i <= n, 'this unit vector does not exist'
|
| 286 |
+
return [ctx.zero]*(i-1) + [ctx.one] + [ctx.zero]*(n-i)
|
| 287 |
+
|
| 288 |
+
def inverse(ctx, A, **kwargs):
|
| 289 |
+
"""
|
| 290 |
+
Calculate the inverse of a matrix.
|
| 291 |
+
|
| 292 |
+
If you want to solve an equation system Ax = b, it's recommended to use
|
| 293 |
+
solve(A, b) instead, it's about 3 times more efficient.
|
| 294 |
+
"""
|
| 295 |
+
prec = ctx.prec
|
| 296 |
+
try:
|
| 297 |
+
ctx.prec += 10
|
| 298 |
+
# do not overwrite A
|
| 299 |
+
A = ctx.matrix(A, **kwargs).copy()
|
| 300 |
+
n = A.rows
|
| 301 |
+
# get LU factorisation
|
| 302 |
+
A, p = ctx.LU_decomp(A)
|
| 303 |
+
cols = []
|
| 304 |
+
# calculate unit vectors and solve corresponding system to get columns
|
| 305 |
+
for i in xrange(1, n + 1):
|
| 306 |
+
e = ctx.unitvector(n, i)
|
| 307 |
+
y = ctx.L_solve(A, e, p)
|
| 308 |
+
cols.append(ctx.U_solve(A, y))
|
| 309 |
+
# convert columns to matrix
|
| 310 |
+
inv = []
|
| 311 |
+
for i in xrange(n):
|
| 312 |
+
row = []
|
| 313 |
+
for j in xrange(n):
|
| 314 |
+
row.append(cols[j][i])
|
| 315 |
+
inv.append(row)
|
| 316 |
+
result = ctx.matrix(inv, **kwargs)
|
| 317 |
+
finally:
|
| 318 |
+
ctx.prec = prec
|
| 319 |
+
return result
|
| 320 |
+
|
| 321 |
+
def householder(ctx, A):
|
| 322 |
+
"""
|
| 323 |
+
(A|b) -> H, p, x, res
|
| 324 |
+
|
| 325 |
+
(A|b) is the coefficient matrix with left hand side of an optionally
|
| 326 |
+
overdetermined linear equation system.
|
| 327 |
+
H and p contain all information about the transformation matrices.
|
| 328 |
+
x is the solution, res the residual.
|
| 329 |
+
"""
|
| 330 |
+
if not isinstance(A, ctx.matrix):
|
| 331 |
+
raise TypeError("A should be a type of ctx.matrix")
|
| 332 |
+
m = A.rows
|
| 333 |
+
n = A.cols
|
| 334 |
+
if m < n - 1:
|
| 335 |
+
raise RuntimeError("Columns should not be less than rows")
|
| 336 |
+
# calculate Householder matrix
|
| 337 |
+
p = []
|
| 338 |
+
for j in xrange(0, n - 1):
|
| 339 |
+
s = ctx.fsum(abs(A[i,j])**2 for i in xrange(j, m))
|
| 340 |
+
if not abs(s) > ctx.eps:
|
| 341 |
+
raise ValueError('matrix is numerically singular')
|
| 342 |
+
p.append(-ctx.sign(ctx.re(A[j,j])) * ctx.sqrt(s))
|
| 343 |
+
kappa = ctx.one / (s - p[j] * A[j,j])
|
| 344 |
+
A[j,j] -= p[j]
|
| 345 |
+
for k in xrange(j+1, n):
|
| 346 |
+
y = ctx.fsum(ctx.conj(A[i,j]) * A[i,k] for i in xrange(j, m)) * kappa
|
| 347 |
+
for i in xrange(j, m):
|
| 348 |
+
A[i,k] -= A[i,j] * y
|
| 349 |
+
# solve Rx = c1
|
| 350 |
+
x = [A[i,n - 1] for i in xrange(n - 1)]
|
| 351 |
+
for i in xrange(n - 2, -1, -1):
|
| 352 |
+
x[i] -= ctx.fsum(A[i,j] * x[j] for j in xrange(i + 1, n - 1))
|
| 353 |
+
x[i] /= p[i]
|
| 354 |
+
# calculate residual
|
| 355 |
+
if not m == n - 1:
|
| 356 |
+
r = [A[m-1-i, n-1] for i in xrange(m - n + 1)]
|
| 357 |
+
else:
|
| 358 |
+
# determined system, residual should be 0
|
| 359 |
+
r = [0]*m # maybe a bad idea, changing r[i] will change all elements
|
| 360 |
+
return A, p, x, r
|
| 361 |
+
|
| 362 |
+
#def qr(ctx, A):
|
| 363 |
+
# """
|
| 364 |
+
# A -> Q, R
|
| 365 |
+
#
|
| 366 |
+
# QR factorisation of a square matrix A using Householder decomposition.
|
| 367 |
+
# Q is orthogonal, this leads to very few numerical errors.
|
| 368 |
+
#
|
| 369 |
+
# A = Q*R
|
| 370 |
+
# """
|
| 371 |
+
# H, p, x, res = householder(A)
|
| 372 |
+
# TODO: implement this
|
| 373 |
+
|
| 374 |
+
def residual(ctx, A, x, b, **kwargs):
|
| 375 |
+
"""
|
| 376 |
+
Calculate the residual of a solution to a linear equation system.
|
| 377 |
+
|
| 378 |
+
r = A*x - b for A*x = b
|
| 379 |
+
"""
|
| 380 |
+
oldprec = ctx.prec
|
| 381 |
+
try:
|
| 382 |
+
ctx.prec *= 2
|
| 383 |
+
A, x, b = ctx.matrix(A, **kwargs), ctx.matrix(x, **kwargs), ctx.matrix(b, **kwargs)
|
| 384 |
+
return A*x - b
|
| 385 |
+
finally:
|
| 386 |
+
ctx.prec = oldprec
|
| 387 |
+
|
| 388 |
+
def qr_solve(ctx, A, b, norm=None, **kwargs):
|
| 389 |
+
"""
|
| 390 |
+
Ax = b => x, ||Ax - b||
|
| 391 |
+
|
| 392 |
+
Solve a determined or overdetermined linear equations system and
|
| 393 |
+
calculate the norm of the residual (error).
|
| 394 |
+
QR decomposition using Householder factorization is applied, which gives very
|
| 395 |
+
accurate results even for ill-conditioned matrices. qr_solve is twice as
|
| 396 |
+
efficient.
|
| 397 |
+
"""
|
| 398 |
+
if norm is None:
|
| 399 |
+
norm = ctx.norm
|
| 400 |
+
prec = ctx.prec
|
| 401 |
+
try:
|
| 402 |
+
ctx.prec += 10
|
| 403 |
+
# do not overwrite A nor b
|
| 404 |
+
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
| 405 |
+
if A.rows < A.cols:
|
| 406 |
+
raise ValueError('cannot solve underdetermined system')
|
| 407 |
+
H, p, x, r = ctx.householder(ctx.extend(A, b))
|
| 408 |
+
res = ctx.norm(r)
|
| 409 |
+
# calculate residual "manually" for determined systems
|
| 410 |
+
if res == 0:
|
| 411 |
+
res = ctx.norm(ctx.residual(A, x, b))
|
| 412 |
+
return ctx.matrix(x, **kwargs), res
|
| 413 |
+
finally:
|
| 414 |
+
ctx.prec = prec
|
| 415 |
+
|
| 416 |
+
def cholesky(ctx, A, tol=None):
|
| 417 |
+
r"""
|
| 418 |
+
Cholesky decomposition of a symmetric positive-definite matrix `A`.
|
| 419 |
+
Returns a lower triangular matrix `L` such that `A = L \times L^T`.
|
| 420 |
+
More generally, for a complex Hermitian positive-definite matrix,
|
| 421 |
+
a Cholesky decomposition satisfying `A = L \times L^H` is returned.
|
| 422 |
+
|
| 423 |
+
The Cholesky decomposition can be used to solve linear equation
|
| 424 |
+
systems twice as efficiently as LU decomposition, or to
|
| 425 |
+
test whether `A` is positive-definite.
|
| 426 |
+
|
| 427 |
+
The optional parameter ``tol`` determines the tolerance for
|
| 428 |
+
verifying positive-definiteness.
|
| 429 |
+
|
| 430 |
+
**Examples**
|
| 431 |
+
|
| 432 |
+
Cholesky decomposition of a positive-definite symmetric matrix::
|
| 433 |
+
|
| 434 |
+
>>> from mpmath import *
|
| 435 |
+
>>> mp.dps = 25; mp.pretty = True
|
| 436 |
+
>>> A = eye(3) + hilbert(3)
|
| 437 |
+
>>> nprint(A)
|
| 438 |
+
[ 2.0 0.5 0.333333]
|
| 439 |
+
[ 0.5 1.33333 0.25]
|
| 440 |
+
[0.333333 0.25 1.2]
|
| 441 |
+
>>> L = cholesky(A)
|
| 442 |
+
>>> nprint(L)
|
| 443 |
+
[ 1.41421 0.0 0.0]
|
| 444 |
+
[0.353553 1.09924 0.0]
|
| 445 |
+
[0.235702 0.15162 1.05899]
|
| 446 |
+
>>> chop(A - L*L.T)
|
| 447 |
+
[0.0 0.0 0.0]
|
| 448 |
+
[0.0 0.0 0.0]
|
| 449 |
+
[0.0 0.0 0.0]
|
| 450 |
+
|
| 451 |
+
Cholesky decomposition of a Hermitian matrix::
|
| 452 |
+
|
| 453 |
+
>>> A = eye(3) + matrix([[0,0.25j,-0.5j],[-0.25j,0,0],[0.5j,0,0]])
|
| 454 |
+
>>> L = cholesky(A)
|
| 455 |
+
>>> nprint(L)
|
| 456 |
+
[ 1.0 0.0 0.0]
|
| 457 |
+
[(0.0 - 0.25j) (0.968246 + 0.0j) 0.0]
|
| 458 |
+
[ (0.0 + 0.5j) (0.129099 + 0.0j) (0.856349 + 0.0j)]
|
| 459 |
+
>>> chop(A - L*L.H)
|
| 460 |
+
[0.0 0.0 0.0]
|
| 461 |
+
[0.0 0.0 0.0]
|
| 462 |
+
[0.0 0.0 0.0]
|
| 463 |
+
|
| 464 |
+
Attempted Cholesky decomposition of a matrix that is not positive
|
| 465 |
+
definite::
|
| 466 |
+
|
| 467 |
+
>>> A = -eye(3) + hilbert(3)
|
| 468 |
+
>>> L = cholesky(A)
|
| 469 |
+
Traceback (most recent call last):
|
| 470 |
+
...
|
| 471 |
+
ValueError: matrix is not positive-definite
|
| 472 |
+
|
| 473 |
+
**References**
|
| 474 |
+
|
| 475 |
+
1. [Wikipedia]_ http://en.wikipedia.org/wiki/Cholesky_decomposition
|
| 476 |
+
|
| 477 |
+
"""
|
| 478 |
+
if not isinstance(A, ctx.matrix):
|
| 479 |
+
raise RuntimeError("A should be a type of ctx.matrix")
|
| 480 |
+
if not A.rows == A.cols:
|
| 481 |
+
raise ValueError('need n*n matrix')
|
| 482 |
+
if tol is None:
|
| 483 |
+
tol = +ctx.eps
|
| 484 |
+
n = A.rows
|
| 485 |
+
L = ctx.matrix(n)
|
| 486 |
+
for j in xrange(n):
|
| 487 |
+
c = ctx.re(A[j,j])
|
| 488 |
+
if abs(c-A[j,j]) > tol:
|
| 489 |
+
raise ValueError('matrix is not Hermitian')
|
| 490 |
+
s = c - ctx.fsum((L[j,k] for k in xrange(j)),
|
| 491 |
+
absolute=True, squared=True)
|
| 492 |
+
if s < tol:
|
| 493 |
+
raise ValueError('matrix is not positive-definite')
|
| 494 |
+
L[j,j] = ctx.sqrt(s)
|
| 495 |
+
for i in xrange(j, n):
|
| 496 |
+
it1 = (L[i,k] for k in xrange(j))
|
| 497 |
+
it2 = (L[j,k] for k in xrange(j))
|
| 498 |
+
t = ctx.fdot(it1, it2, conjugate=True)
|
| 499 |
+
L[i,j] = (A[i,j] - t) / L[j,j]
|
| 500 |
+
return L
|
| 501 |
+
|
| 502 |
+
def cholesky_solve(ctx, A, b, **kwargs):
|
| 503 |
+
"""
|
| 504 |
+
Ax = b => x
|
| 505 |
+
|
| 506 |
+
Solve a symmetric positive-definite linear equation system.
|
| 507 |
+
This is twice as efficient as lu_solve.
|
| 508 |
+
|
| 509 |
+
Typical use cases:
|
| 510 |
+
* A.T*A
|
| 511 |
+
* Hessian matrix
|
| 512 |
+
* differential equations
|
| 513 |
+
"""
|
| 514 |
+
prec = ctx.prec
|
| 515 |
+
try:
|
| 516 |
+
ctx.prec += 10
|
| 517 |
+
# do not overwrite A nor b
|
| 518 |
+
A, b = ctx.matrix(A, **kwargs).copy(), ctx.matrix(b, **kwargs).copy()
|
| 519 |
+
if A.rows != A.cols:
|
| 520 |
+
raise ValueError('can only solve determined system')
|
| 521 |
+
# Cholesky factorization
|
| 522 |
+
L = ctx.cholesky(A)
|
| 523 |
+
# solve
|
| 524 |
+
n = L.rows
|
| 525 |
+
if len(b) != n:
|
| 526 |
+
raise ValueError("Value should be equal to n")
|
| 527 |
+
for i in xrange(n):
|
| 528 |
+
b[i] -= ctx.fsum(L[i,j] * b[j] for j in xrange(i))
|
| 529 |
+
b[i] /= L[i,i]
|
| 530 |
+
x = ctx.U_solve(L.T, b)
|
| 531 |
+
return x
|
| 532 |
+
finally:
|
| 533 |
+
ctx.prec = prec
|
| 534 |
+
|
| 535 |
+
def det(ctx, A):
|
| 536 |
+
"""
|
| 537 |
+
Calculate the determinant of a matrix.
|
| 538 |
+
"""
|
| 539 |
+
prec = ctx.prec
|
| 540 |
+
try:
|
| 541 |
+
# do not overwrite A
|
| 542 |
+
A = ctx.matrix(A).copy()
|
| 543 |
+
# use LU factorization to calculate determinant
|
| 544 |
+
try:
|
| 545 |
+
R, p = ctx.LU_decomp(A)
|
| 546 |
+
except ZeroDivisionError:
|
| 547 |
+
return 0
|
| 548 |
+
z = 1
|
| 549 |
+
for i, e in enumerate(p):
|
| 550 |
+
if i != e:
|
| 551 |
+
z *= -1
|
| 552 |
+
for i in xrange(A.rows):
|
| 553 |
+
z *= R[i,i]
|
| 554 |
+
return z
|
| 555 |
+
finally:
|
| 556 |
+
ctx.prec = prec
|
| 557 |
+
|
| 558 |
+
def cond(ctx, A, norm=None):
|
| 559 |
+
"""
|
| 560 |
+
Calculate the condition number of a matrix using a specified matrix norm.
|
| 561 |
+
|
| 562 |
+
The condition number estimates the sensitivity of a matrix to errors.
|
| 563 |
+
Example: small input errors for ill-conditioned coefficient matrices
|
| 564 |
+
alter the solution of the system dramatically.
|
| 565 |
+
|
| 566 |
+
For ill-conditioned matrices it's recommended to use qr_solve() instead
|
| 567 |
+
of lu_solve(). This does not help with input errors however, it just avoids
|
| 568 |
+
to add additional errors.
|
| 569 |
+
|
| 570 |
+
Definition: cond(A) = ||A|| * ||A**-1||
|
| 571 |
+
"""
|
| 572 |
+
if norm is None:
|
| 573 |
+
norm = lambda x: ctx.mnorm(x,1)
|
| 574 |
+
return norm(A) * norm(ctx.inverse(A))
|
| 575 |
+
|
| 576 |
+
def lu_solve_mat(ctx, a, b):
|
| 577 |
+
"""Solve a * x = b where a and b are matrices."""
|
| 578 |
+
r = ctx.matrix(a.rows, b.cols)
|
| 579 |
+
for i in range(b.cols):
|
| 580 |
+
c = ctx.lu_solve(a, b.column(i))
|
| 581 |
+
for j in range(len(c)):
|
| 582 |
+
r[j, i] = c[j]
|
| 583 |
+
return r
|
| 584 |
+
|
| 585 |
+
def qr(ctx, A, mode = 'full', edps = 10):
|
| 586 |
+
"""
|
| 587 |
+
Compute a QR factorization $A = QR$ where
|
| 588 |
+
A is an m x n matrix of real or complex numbers where m >= n
|
| 589 |
+
|
| 590 |
+
mode has following meanings:
|
| 591 |
+
(1) mode = 'raw' returns two matrixes (A, tau) in the
|
| 592 |
+
internal format used by LAPACK
|
| 593 |
+
(2) mode = 'skinny' returns the leading n columns of Q
|
| 594 |
+
and n rows of R
|
| 595 |
+
(3) Any other value returns the leading m columns of Q
|
| 596 |
+
and m rows of R
|
| 597 |
+
|
| 598 |
+
edps is the increase in mp precision used for calculations
|
| 599 |
+
|
| 600 |
+
**Examples**
|
| 601 |
+
|
| 602 |
+
>>> from mpmath import *
|
| 603 |
+
>>> mp.dps = 15
|
| 604 |
+
>>> mp.pretty = True
|
| 605 |
+
>>> A = matrix([[1, 2], [3, 4], [1, 1]])
|
| 606 |
+
>>> Q, R = qr(A)
|
| 607 |
+
>>> Q
|
| 608 |
+
[-0.301511344577764 0.861640436855329 0.408248290463863]
|
| 609 |
+
[-0.904534033733291 -0.123091490979333 -0.408248290463863]
|
| 610 |
+
[-0.301511344577764 -0.492365963917331 0.816496580927726]
|
| 611 |
+
>>> R
|
| 612 |
+
[-3.3166247903554 -4.52267016866645]
|
| 613 |
+
[ 0.0 0.738548945875996]
|
| 614 |
+
[ 0.0 0.0]
|
| 615 |
+
>>> Q * R
|
| 616 |
+
[1.0 2.0]
|
| 617 |
+
[3.0 4.0]
|
| 618 |
+
[1.0 1.0]
|
| 619 |
+
>>> chop(Q.T * Q)
|
| 620 |
+
[1.0 0.0 0.0]
|
| 621 |
+
[0.0 1.0 0.0]
|
| 622 |
+
[0.0 0.0 1.0]
|
| 623 |
+
>>> B = matrix([[1+0j, 2-3j], [3+j, 4+5j]])
|
| 624 |
+
>>> Q, R = qr(B)
|
| 625 |
+
>>> nprint(Q)
|
| 626 |
+
[ (-0.301511 + 0.0j) (0.0695795 - 0.95092j)]
|
| 627 |
+
[(-0.904534 - 0.301511j) (-0.115966 + 0.278318j)]
|
| 628 |
+
>>> nprint(R)
|
| 629 |
+
[(-3.31662 + 0.0j) (-5.72872 - 2.41209j)]
|
| 630 |
+
[ 0.0 (3.91965 + 0.0j)]
|
| 631 |
+
>>> Q * R
|
| 632 |
+
[(1.0 + 0.0j) (2.0 - 3.0j)]
|
| 633 |
+
[(3.0 + 1.0j) (4.0 + 5.0j)]
|
| 634 |
+
>>> chop(Q.T * Q.conjugate())
|
| 635 |
+
[1.0 0.0]
|
| 636 |
+
[0.0 1.0]
|
| 637 |
+
|
| 638 |
+
"""
|
| 639 |
+
|
| 640 |
+
# check values before continuing
|
| 641 |
+
assert isinstance(A, ctx.matrix)
|
| 642 |
+
m = A.rows
|
| 643 |
+
n = A.cols
|
| 644 |
+
assert n >= 0
|
| 645 |
+
assert m >= n
|
| 646 |
+
assert edps >= 0
|
| 647 |
+
|
| 648 |
+
# check for complex data type
|
| 649 |
+
cmplx = any(type(x) is ctx.mpc for x in A)
|
| 650 |
+
|
| 651 |
+
# temporarily increase the precision and initialize
|
| 652 |
+
with ctx.extradps(edps):
|
| 653 |
+
tau = ctx.matrix(n,1)
|
| 654 |
+
A = A.copy()
|
| 655 |
+
|
| 656 |
+
# ---------------
|
| 657 |
+
# FACTOR MATRIX A
|
| 658 |
+
# ---------------
|
| 659 |
+
if cmplx:
|
| 660 |
+
one = ctx.mpc('1.0', '0.0')
|
| 661 |
+
zero = ctx.mpc('0.0', '0.0')
|
| 662 |
+
rzero = ctx.mpf('0.0')
|
| 663 |
+
|
| 664 |
+
# main loop to factor A (complex)
|
| 665 |
+
for j in xrange(0, n):
|
| 666 |
+
alpha = A[j,j]
|
| 667 |
+
alphr = ctx.re(alpha)
|
| 668 |
+
alphi = ctx.im(alpha)
|
| 669 |
+
|
| 670 |
+
if (m-j) >= 2:
|
| 671 |
+
xnorm = ctx.fsum( A[i,j]*ctx.conj(A[i,j]) for i in xrange(j+1, m) )
|
| 672 |
+
xnorm = ctx.re( ctx.sqrt(xnorm) )
|
| 673 |
+
else:
|
| 674 |
+
xnorm = rzero
|
| 675 |
+
|
| 676 |
+
if (xnorm == rzero) and (alphi == rzero):
|
| 677 |
+
tau[j] = zero
|
| 678 |
+
continue
|
| 679 |
+
|
| 680 |
+
if alphr < rzero:
|
| 681 |
+
beta = ctx.sqrt(alphr**2 + alphi**2 + xnorm**2)
|
| 682 |
+
else:
|
| 683 |
+
beta = -ctx.sqrt(alphr**2 + alphi**2 + xnorm**2)
|
| 684 |
+
|
| 685 |
+
tau[j] = ctx.mpc( (beta - alphr) / beta, -alphi / beta )
|
| 686 |
+
t = -ctx.conj(tau[j])
|
| 687 |
+
za = one / (alpha - beta)
|
| 688 |
+
|
| 689 |
+
for i in xrange(j+1, m):
|
| 690 |
+
A[i,j] *= za
|
| 691 |
+
|
| 692 |
+
A[j,j] = one
|
| 693 |
+
for k in xrange(j+1, n):
|
| 694 |
+
y = ctx.fsum(A[i,j] * ctx.conj(A[i,k]) for i in xrange(j, m))
|
| 695 |
+
temp = t * ctx.conj(y)
|
| 696 |
+
for i in xrange(j, m):
|
| 697 |
+
A[i,k] += A[i,j] * temp
|
| 698 |
+
|
| 699 |
+
A[j,j] = ctx.mpc(beta, '0.0')
|
| 700 |
+
else:
|
| 701 |
+
one = ctx.mpf('1.0')
|
| 702 |
+
zero = ctx.mpf('0.0')
|
| 703 |
+
|
| 704 |
+
# main loop to factor A (real)
|
| 705 |
+
for j in xrange(0, n):
|
| 706 |
+
alpha = A[j,j]
|
| 707 |
+
|
| 708 |
+
if (m-j) > 2:
|
| 709 |
+
xnorm = ctx.fsum( (A[i,j])**2 for i in xrange(j+1, m) )
|
| 710 |
+
xnorm = ctx.sqrt(xnorm)
|
| 711 |
+
elif (m-j) == 2:
|
| 712 |
+
xnorm = abs( A[m-1,j] )
|
| 713 |
+
else:
|
| 714 |
+
xnorm = zero
|
| 715 |
+
|
| 716 |
+
if xnorm == zero:
|
| 717 |
+
tau[j] = zero
|
| 718 |
+
continue
|
| 719 |
+
|
| 720 |
+
if alpha < zero:
|
| 721 |
+
beta = ctx.sqrt(alpha**2 + xnorm**2)
|
| 722 |
+
else:
|
| 723 |
+
beta = -ctx.sqrt(alpha**2 + xnorm**2)
|
| 724 |
+
|
| 725 |
+
tau[j] = (beta - alpha) / beta
|
| 726 |
+
t = -tau[j]
|
| 727 |
+
da = one / (alpha - beta)
|
| 728 |
+
|
| 729 |
+
for i in xrange(j+1, m):
|
| 730 |
+
A[i,j] *= da
|
| 731 |
+
|
| 732 |
+
A[j,j] = one
|
| 733 |
+
for k in xrange(j+1, n):
|
| 734 |
+
y = ctx.fsum( A[i,j] * A[i,k] for i in xrange(j, m) )
|
| 735 |
+
temp = t * y
|
| 736 |
+
for i in xrange(j,m):
|
| 737 |
+
A[i,k] += A[i,j] * temp
|
| 738 |
+
|
| 739 |
+
A[j,j] = beta
|
| 740 |
+
|
| 741 |
+
# return factorization in same internal format as LAPACK
|
| 742 |
+
if (mode == 'raw') or (mode == 'RAW'):
|
| 743 |
+
return A, tau
|
| 744 |
+
|
| 745 |
+
# ----------------------------------
|
| 746 |
+
# FORM Q USING BACKWARD ACCUMULATION
|
| 747 |
+
# ----------------------------------
|
| 748 |
+
|
| 749 |
+
# form R before the values are overwritten
|
| 750 |
+
R = A.copy()
|
| 751 |
+
for j in xrange(0, n):
|
| 752 |
+
for i in xrange(j+1, m):
|
| 753 |
+
R[i,j] = zero
|
| 754 |
+
|
| 755 |
+
# set the value of p (number of columns of Q to return)
|
| 756 |
+
p = m
|
| 757 |
+
if (mode == 'skinny') or (mode == 'SKINNY'):
|
| 758 |
+
p = n
|
| 759 |
+
|
| 760 |
+
# add columns to A if needed and initialize
|
| 761 |
+
A.cols += (p-n)
|
| 762 |
+
for j in xrange(0, p):
|
| 763 |
+
A[j,j] = one
|
| 764 |
+
for i in xrange(0, j):
|
| 765 |
+
A[i,j] = zero
|
| 766 |
+
|
| 767 |
+
# main loop to form Q
|
| 768 |
+
for j in xrange(n-1, -1, -1):
|
| 769 |
+
t = -tau[j]
|
| 770 |
+
A[j,j] += t
|
| 771 |
+
|
| 772 |
+
for k in xrange(j+1, p):
|
| 773 |
+
if cmplx:
|
| 774 |
+
y = ctx.fsum(A[i,j] * ctx.conj(A[i,k]) for i in xrange(j+1, m))
|
| 775 |
+
temp = t * ctx.conj(y)
|
| 776 |
+
else:
|
| 777 |
+
y = ctx.fsum(A[i,j] * A[i,k] for i in xrange(j+1, m))
|
| 778 |
+
temp = t * y
|
| 779 |
+
A[j,k] = temp
|
| 780 |
+
for i in xrange(j+1, m):
|
| 781 |
+
A[i,k] += A[i,j] * temp
|
| 782 |
+
|
| 783 |
+
for i in xrange(j+1, m):
|
| 784 |
+
A[i, j] *= t
|
| 785 |
+
|
| 786 |
+
return A, R[0:p,0:n]
|
| 787 |
+
|
| 788 |
+
# ------------------
|
| 789 |
+
# END OF FUNCTION QR
|
| 790 |
+
# ------------------
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/matrices/matrices.py
ADDED
|
@@ -0,0 +1,1005 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..libmp.backend import xrange
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
# TODO: interpret list as vectors (for multiplication)
|
| 5 |
+
|
| 6 |
+
rowsep = '\n'
|
| 7 |
+
colsep = ' '
|
| 8 |
+
|
| 9 |
+
class _matrix(object):
|
| 10 |
+
"""
|
| 11 |
+
Numerical matrix.
|
| 12 |
+
|
| 13 |
+
Specify the dimensions or the data as a nested list.
|
| 14 |
+
Elements default to zero.
|
| 15 |
+
Use a flat list to create a column vector easily.
|
| 16 |
+
|
| 17 |
+
The datatype of the context (mpf for mp, mpi for iv, and float for fp) is used to store the data.
|
| 18 |
+
|
| 19 |
+
Creating matrices
|
| 20 |
+
-----------------
|
| 21 |
+
|
| 22 |
+
Matrices in mpmath are implemented using dictionaries. Only non-zero values
|
| 23 |
+
are stored, so it is cheap to represent sparse matrices.
|
| 24 |
+
|
| 25 |
+
The most basic way to create one is to use the ``matrix`` class directly.
|
| 26 |
+
You can create an empty matrix specifying the dimensions:
|
| 27 |
+
|
| 28 |
+
>>> from mpmath import *
|
| 29 |
+
>>> mp.dps = 15
|
| 30 |
+
>>> matrix(2)
|
| 31 |
+
matrix(
|
| 32 |
+
[['0.0', '0.0'],
|
| 33 |
+
['0.0', '0.0']])
|
| 34 |
+
>>> matrix(2, 3)
|
| 35 |
+
matrix(
|
| 36 |
+
[['0.0', '0.0', '0.0'],
|
| 37 |
+
['0.0', '0.0', '0.0']])
|
| 38 |
+
|
| 39 |
+
Calling ``matrix`` with one dimension will create a square matrix.
|
| 40 |
+
|
| 41 |
+
To access the dimensions of a matrix, use the ``rows`` or ``cols`` keyword:
|
| 42 |
+
|
| 43 |
+
>>> A = matrix(3, 2)
|
| 44 |
+
>>> A
|
| 45 |
+
matrix(
|
| 46 |
+
[['0.0', '0.0'],
|
| 47 |
+
['0.0', '0.0'],
|
| 48 |
+
['0.0', '0.0']])
|
| 49 |
+
>>> A.rows
|
| 50 |
+
3
|
| 51 |
+
>>> A.cols
|
| 52 |
+
2
|
| 53 |
+
|
| 54 |
+
You can also change the dimension of an existing matrix. This will set the
|
| 55 |
+
new elements to 0. If the new dimension is smaller than before, the
|
| 56 |
+
concerning elements are discarded:
|
| 57 |
+
|
| 58 |
+
>>> A.rows = 2
|
| 59 |
+
>>> A
|
| 60 |
+
matrix(
|
| 61 |
+
[['0.0', '0.0'],
|
| 62 |
+
['0.0', '0.0']])
|
| 63 |
+
|
| 64 |
+
Internally ``mpmathify`` is used every time an element is set. This
|
| 65 |
+
is done using the syntax A[row,column], counting from 0:
|
| 66 |
+
|
| 67 |
+
>>> A = matrix(2)
|
| 68 |
+
>>> A[1,1] = 1 + 1j
|
| 69 |
+
>>> A
|
| 70 |
+
matrix(
|
| 71 |
+
[['0.0', '0.0'],
|
| 72 |
+
['0.0', mpc(real='1.0', imag='1.0')]])
|
| 73 |
+
|
| 74 |
+
A more comfortable way to create a matrix lets you use nested lists:
|
| 75 |
+
|
| 76 |
+
>>> matrix([[1, 2], [3, 4]])
|
| 77 |
+
matrix(
|
| 78 |
+
[['1.0', '2.0'],
|
| 79 |
+
['3.0', '4.0']])
|
| 80 |
+
|
| 81 |
+
Convenient advanced functions are available for creating various standard
|
| 82 |
+
matrices, see ``zeros``, ``ones``, ``diag``, ``eye``, ``randmatrix`` and
|
| 83 |
+
``hilbert``.
|
| 84 |
+
|
| 85 |
+
Vectors
|
| 86 |
+
.......
|
| 87 |
+
|
| 88 |
+
Vectors may also be represented by the ``matrix`` class (with rows = 1 or cols = 1).
|
| 89 |
+
For vectors there are some things which make life easier. A column vector can
|
| 90 |
+
be created using a flat list, a row vectors using an almost flat nested list::
|
| 91 |
+
|
| 92 |
+
>>> matrix([1, 2, 3])
|
| 93 |
+
matrix(
|
| 94 |
+
[['1.0'],
|
| 95 |
+
['2.0'],
|
| 96 |
+
['3.0']])
|
| 97 |
+
>>> matrix([[1, 2, 3]])
|
| 98 |
+
matrix(
|
| 99 |
+
[['1.0', '2.0', '3.0']])
|
| 100 |
+
|
| 101 |
+
Optionally vectors can be accessed like lists, using only a single index::
|
| 102 |
+
|
| 103 |
+
>>> x = matrix([1, 2, 3])
|
| 104 |
+
>>> x[1]
|
| 105 |
+
mpf('2.0')
|
| 106 |
+
>>> x[1,0]
|
| 107 |
+
mpf('2.0')
|
| 108 |
+
|
| 109 |
+
Other
|
| 110 |
+
.....
|
| 111 |
+
|
| 112 |
+
Like you probably expected, matrices can be printed::
|
| 113 |
+
|
| 114 |
+
>>> print randmatrix(3) # doctest:+SKIP
|
| 115 |
+
[ 0.782963853573023 0.802057689719883 0.427895717335467]
|
| 116 |
+
[0.0541876859348597 0.708243266653103 0.615134039977379]
|
| 117 |
+
[ 0.856151514955773 0.544759264818486 0.686210904770947]
|
| 118 |
+
|
| 119 |
+
Use ``nstr`` or ``nprint`` to specify the number of digits to print::
|
| 120 |
+
|
| 121 |
+
>>> nprint(randmatrix(5), 3) # doctest:+SKIP
|
| 122 |
+
[2.07e-1 1.66e-1 5.06e-1 1.89e-1 8.29e-1]
|
| 123 |
+
[6.62e-1 6.55e-1 4.47e-1 4.82e-1 2.06e-2]
|
| 124 |
+
[4.33e-1 7.75e-1 6.93e-2 2.86e-1 5.71e-1]
|
| 125 |
+
[1.01e-1 2.53e-1 6.13e-1 3.32e-1 2.59e-1]
|
| 126 |
+
[1.56e-1 7.27e-2 6.05e-1 6.67e-2 2.79e-1]
|
| 127 |
+
|
| 128 |
+
As matrices are mutable, you will need to copy them sometimes::
|
| 129 |
+
|
| 130 |
+
>>> A = matrix(2)
|
| 131 |
+
>>> A
|
| 132 |
+
matrix(
|
| 133 |
+
[['0.0', '0.0'],
|
| 134 |
+
['0.0', '0.0']])
|
| 135 |
+
>>> B = A.copy()
|
| 136 |
+
>>> B[0,0] = 1
|
| 137 |
+
>>> B
|
| 138 |
+
matrix(
|
| 139 |
+
[['1.0', '0.0'],
|
| 140 |
+
['0.0', '0.0']])
|
| 141 |
+
>>> A
|
| 142 |
+
matrix(
|
| 143 |
+
[['0.0', '0.0'],
|
| 144 |
+
['0.0', '0.0']])
|
| 145 |
+
|
| 146 |
+
Finally, it is possible to convert a matrix to a nested list. This is very useful,
|
| 147 |
+
as most Python libraries involving matrices or arrays (namely NumPy or SymPy)
|
| 148 |
+
support this format::
|
| 149 |
+
|
| 150 |
+
>>> B.tolist()
|
| 151 |
+
[[mpf('1.0'), mpf('0.0')], [mpf('0.0'), mpf('0.0')]]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
Matrix operations
|
| 155 |
+
-----------------
|
| 156 |
+
|
| 157 |
+
You can add and subtract matrices of compatible dimensions::
|
| 158 |
+
|
| 159 |
+
>>> A = matrix([[1, 2], [3, 4]])
|
| 160 |
+
>>> B = matrix([[-2, 4], [5, 9]])
|
| 161 |
+
>>> A + B
|
| 162 |
+
matrix(
|
| 163 |
+
[['-1.0', '6.0'],
|
| 164 |
+
['8.0', '13.0']])
|
| 165 |
+
>>> A - B
|
| 166 |
+
matrix(
|
| 167 |
+
[['3.0', '-2.0'],
|
| 168 |
+
['-2.0', '-5.0']])
|
| 169 |
+
>>> A + ones(3) # doctest:+ELLIPSIS
|
| 170 |
+
Traceback (most recent call last):
|
| 171 |
+
...
|
| 172 |
+
ValueError: incompatible dimensions for addition
|
| 173 |
+
|
| 174 |
+
It is possible to multiply or add matrices and scalars. In the latter case the
|
| 175 |
+
operation will be done element-wise::
|
| 176 |
+
|
| 177 |
+
>>> A * 2
|
| 178 |
+
matrix(
|
| 179 |
+
[['2.0', '4.0'],
|
| 180 |
+
['6.0', '8.0']])
|
| 181 |
+
>>> A / 4
|
| 182 |
+
matrix(
|
| 183 |
+
[['0.25', '0.5'],
|
| 184 |
+
['0.75', '1.0']])
|
| 185 |
+
>>> A - 1
|
| 186 |
+
matrix(
|
| 187 |
+
[['0.0', '1.0'],
|
| 188 |
+
['2.0', '3.0']])
|
| 189 |
+
|
| 190 |
+
Of course you can perform matrix multiplication, if the dimensions are
|
| 191 |
+
compatible, using ``@`` (for Python >= 3.5) or ``*``. For clarity, ``@`` is
|
| 192 |
+
recommended (`PEP 465 <https://www.python.org/dev/peps/pep-0465/>`), because
|
| 193 |
+
the meaning of ``*`` is different in many other Python libraries such as NumPy.
|
| 194 |
+
|
| 195 |
+
>>> A @ B # doctest:+SKIP
|
| 196 |
+
matrix(
|
| 197 |
+
[['8.0', '22.0'],
|
| 198 |
+
['14.0', '48.0']])
|
| 199 |
+
>>> A * B # same as A @ B
|
| 200 |
+
matrix(
|
| 201 |
+
[['8.0', '22.0'],
|
| 202 |
+
['14.0', '48.0']])
|
| 203 |
+
>>> matrix([[1, 2, 3]]) * matrix([[-6], [7], [-2]])
|
| 204 |
+
matrix(
|
| 205 |
+
[['2.0']])
|
| 206 |
+
|
| 207 |
+
..
|
| 208 |
+
COMMENT: TODO: the above "doctest:+SKIP" may be removed as soon as we
|
| 209 |
+
have dropped support for Python 3.5 and below.
|
| 210 |
+
|
| 211 |
+
You can raise powers of square matrices::
|
| 212 |
+
|
| 213 |
+
>>> A**2
|
| 214 |
+
matrix(
|
| 215 |
+
[['7.0', '10.0'],
|
| 216 |
+
['15.0', '22.0']])
|
| 217 |
+
|
| 218 |
+
Negative powers will calculate the inverse::
|
| 219 |
+
|
| 220 |
+
>>> A**-1
|
| 221 |
+
matrix(
|
| 222 |
+
[['-2.0', '1.0'],
|
| 223 |
+
['1.5', '-0.5']])
|
| 224 |
+
>>> A * A**-1
|
| 225 |
+
matrix(
|
| 226 |
+
[['1.0', '1.0842021724855e-19'],
|
| 227 |
+
['-2.16840434497101e-19', '1.0']])
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
Matrix transposition is straightforward::
|
| 232 |
+
|
| 233 |
+
>>> A = ones(2, 3)
|
| 234 |
+
>>> A
|
| 235 |
+
matrix(
|
| 236 |
+
[['1.0', '1.0', '1.0'],
|
| 237 |
+
['1.0', '1.0', '1.0']])
|
| 238 |
+
>>> A.T
|
| 239 |
+
matrix(
|
| 240 |
+
[['1.0', '1.0'],
|
| 241 |
+
['1.0', '1.0'],
|
| 242 |
+
['1.0', '1.0']])
|
| 243 |
+
|
| 244 |
+
Norms
|
| 245 |
+
.....
|
| 246 |
+
|
| 247 |
+
Sometimes you need to know how "large" a matrix or vector is. Due to their
|
| 248 |
+
multidimensional nature it's not possible to compare them, but there are
|
| 249 |
+
several functions to map a matrix or a vector to a positive real number, the
|
| 250 |
+
so called norms.
|
| 251 |
+
|
| 252 |
+
For vectors the p-norm is intended, usually the 1-, the 2- and the oo-norm are
|
| 253 |
+
used.
|
| 254 |
+
|
| 255 |
+
>>> x = matrix([-10, 2, 100])
|
| 256 |
+
>>> norm(x, 1)
|
| 257 |
+
mpf('112.0')
|
| 258 |
+
>>> norm(x, 2)
|
| 259 |
+
mpf('100.5186549850325')
|
| 260 |
+
>>> norm(x, inf)
|
| 261 |
+
mpf('100.0')
|
| 262 |
+
|
| 263 |
+
Please note that the 2-norm is the most used one, though it is more expensive
|
| 264 |
+
to calculate than the 1- or oo-norm.
|
| 265 |
+
|
| 266 |
+
It is possible to generalize some vector norms to matrix norm::
|
| 267 |
+
|
| 268 |
+
>>> A = matrix([[1, -1000], [100, 50]])
|
| 269 |
+
>>> mnorm(A, 1)
|
| 270 |
+
mpf('1050.0')
|
| 271 |
+
>>> mnorm(A, inf)
|
| 272 |
+
mpf('1001.0')
|
| 273 |
+
>>> mnorm(A, 'F')
|
| 274 |
+
mpf('1006.2310867787777')
|
| 275 |
+
|
| 276 |
+
The last norm (the "Frobenius-norm") is an approximation for the 2-norm, which
|
| 277 |
+
is hard to calculate and not available. The Frobenius-norm lacks some
|
| 278 |
+
mathematical properties you might expect from a norm.
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
def __init__(self, *args, **kwargs):
|
| 282 |
+
self.__data = {}
|
| 283 |
+
# LU decompostion cache, this is useful when solving the same system
|
| 284 |
+
# multiple times, when calculating the inverse and when calculating the
|
| 285 |
+
# determinant
|
| 286 |
+
self._LU = None
|
| 287 |
+
if "force_type" in kwargs:
|
| 288 |
+
warnings.warn("The force_type argument was removed, it did not work"
|
| 289 |
+
" properly anyway. If you want to force floating-point or"
|
| 290 |
+
" interval computations, use the respective methods from `fp`"
|
| 291 |
+
" or `mp` instead, e.g., `fp.matrix()` or `iv.matrix()`."
|
| 292 |
+
" If you want to truncate values to integer, use .apply(int) instead.")
|
| 293 |
+
if isinstance(args[0], (list, tuple)):
|
| 294 |
+
if isinstance(args[0][0], (list, tuple)):
|
| 295 |
+
# interpret nested list as matrix
|
| 296 |
+
A = args[0]
|
| 297 |
+
self.__rows = len(A)
|
| 298 |
+
self.__cols = len(A[0])
|
| 299 |
+
for i, row in enumerate(A):
|
| 300 |
+
for j, a in enumerate(row):
|
| 301 |
+
# note: this will call __setitem__ which will call self.ctx.convert() to convert the datatype.
|
| 302 |
+
self[i, j] = a
|
| 303 |
+
else:
|
| 304 |
+
# interpret list as row vector
|
| 305 |
+
v = args[0]
|
| 306 |
+
self.__rows = len(v)
|
| 307 |
+
self.__cols = 1
|
| 308 |
+
for i, e in enumerate(v):
|
| 309 |
+
self[i, 0] = e
|
| 310 |
+
elif isinstance(args[0], int):
|
| 311 |
+
# create empty matrix of given dimensions
|
| 312 |
+
if len(args) == 1:
|
| 313 |
+
self.__rows = self.__cols = args[0]
|
| 314 |
+
else:
|
| 315 |
+
if not isinstance(args[1], int):
|
| 316 |
+
raise TypeError("expected int")
|
| 317 |
+
self.__rows = args[0]
|
| 318 |
+
self.__cols = args[1]
|
| 319 |
+
elif isinstance(args[0], _matrix):
|
| 320 |
+
A = args[0]
|
| 321 |
+
self.__rows = A._matrix__rows
|
| 322 |
+
self.__cols = A._matrix__cols
|
| 323 |
+
for i in xrange(A.__rows):
|
| 324 |
+
for j in xrange(A.__cols):
|
| 325 |
+
self[i, j] = A[i, j]
|
| 326 |
+
elif hasattr(args[0], 'tolist'):
|
| 327 |
+
A = self.ctx.matrix(args[0].tolist())
|
| 328 |
+
self.__data = A._matrix__data
|
| 329 |
+
self.__rows = A._matrix__rows
|
| 330 |
+
self.__cols = A._matrix__cols
|
| 331 |
+
else:
|
| 332 |
+
raise TypeError('could not interpret given arguments')
|
| 333 |
+
|
| 334 |
+
def apply(self, f):
|
| 335 |
+
"""
|
| 336 |
+
Return a copy of self with the function `f` applied elementwise.
|
| 337 |
+
"""
|
| 338 |
+
new = self.ctx.matrix(self.__rows, self.__cols)
|
| 339 |
+
for i in xrange(self.__rows):
|
| 340 |
+
for j in xrange(self.__cols):
|
| 341 |
+
new[i,j] = f(self[i,j])
|
| 342 |
+
return new
|
| 343 |
+
|
| 344 |
+
def __nstr__(self, n=None, **kwargs):
|
| 345 |
+
# Build table of string representations of the elements
|
| 346 |
+
res = []
|
| 347 |
+
# Track per-column max lengths for pretty alignment
|
| 348 |
+
maxlen = [0] * self.cols
|
| 349 |
+
for i in range(self.rows):
|
| 350 |
+
res.append([])
|
| 351 |
+
for j in range(self.cols):
|
| 352 |
+
if n:
|
| 353 |
+
string = self.ctx.nstr(self[i,j], n, **kwargs)
|
| 354 |
+
else:
|
| 355 |
+
string = str(self[i,j])
|
| 356 |
+
res[-1].append(string)
|
| 357 |
+
maxlen[j] = max(len(string), maxlen[j])
|
| 358 |
+
# Patch strings together
|
| 359 |
+
for i, row in enumerate(res):
|
| 360 |
+
for j, elem in enumerate(row):
|
| 361 |
+
# Pad each element up to maxlen so the columns line up
|
| 362 |
+
row[j] = elem.rjust(maxlen[j])
|
| 363 |
+
res[i] = "[" + colsep.join(row) + "]"
|
| 364 |
+
return rowsep.join(res)
|
| 365 |
+
|
| 366 |
+
def __str__(self):
|
| 367 |
+
return self.__nstr__()
|
| 368 |
+
|
| 369 |
+
def _toliststr(self, avoid_type=False):
|
| 370 |
+
"""
|
| 371 |
+
Create a list string from a matrix.
|
| 372 |
+
|
| 373 |
+
If avoid_type: avoid multiple 'mpf's.
|
| 374 |
+
"""
|
| 375 |
+
# XXX: should be something like self.ctx._types
|
| 376 |
+
typ = self.ctx.mpf
|
| 377 |
+
s = '['
|
| 378 |
+
for i in xrange(self.__rows):
|
| 379 |
+
s += '['
|
| 380 |
+
for j in xrange(self.__cols):
|
| 381 |
+
if not avoid_type or not isinstance(self[i,j], typ):
|
| 382 |
+
a = repr(self[i,j])
|
| 383 |
+
else:
|
| 384 |
+
a = "'" + str(self[i,j]) + "'"
|
| 385 |
+
s += a + ', '
|
| 386 |
+
s = s[:-2]
|
| 387 |
+
s += '],\n '
|
| 388 |
+
s = s[:-3]
|
| 389 |
+
s += ']'
|
| 390 |
+
return s
|
| 391 |
+
|
| 392 |
+
def tolist(self):
|
| 393 |
+
"""
|
| 394 |
+
Convert the matrix to a nested list.
|
| 395 |
+
"""
|
| 396 |
+
return [[self[i,j] for j in range(self.__cols)] for i in range(self.__rows)]
|
| 397 |
+
|
| 398 |
+
def __repr__(self):
|
| 399 |
+
if self.ctx.pretty:
|
| 400 |
+
return self.__str__()
|
| 401 |
+
s = 'matrix(\n'
|
| 402 |
+
s += self._toliststr(avoid_type=True) + ')'
|
| 403 |
+
return s
|
| 404 |
+
|
| 405 |
+
def __get_element(self, key):
|
| 406 |
+
'''
|
| 407 |
+
Fast extraction of the i,j element from the matrix
|
| 408 |
+
This function is for private use only because is unsafe:
|
| 409 |
+
1. Does not check on the value of key it expects key to be a integer tuple (i,j)
|
| 410 |
+
2. Does not check bounds
|
| 411 |
+
'''
|
| 412 |
+
if key in self.__data:
|
| 413 |
+
return self.__data[key]
|
| 414 |
+
else:
|
| 415 |
+
return self.ctx.zero
|
| 416 |
+
|
| 417 |
+
def __set_element(self, key, value):
|
| 418 |
+
'''
|
| 419 |
+
Fast assignment of the i,j element in the matrix
|
| 420 |
+
This function is unsafe:
|
| 421 |
+
1. Does not check on the value of key it expects key to be a integer tuple (i,j)
|
| 422 |
+
2. Does not check bounds
|
| 423 |
+
3. Does not check the value type
|
| 424 |
+
4. Does not reset the LU cache
|
| 425 |
+
'''
|
| 426 |
+
if value: # only store non-zeros
|
| 427 |
+
self.__data[key] = value
|
| 428 |
+
elif key in self.__data:
|
| 429 |
+
del self.__data[key]
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def __getitem__(self, key):
|
| 433 |
+
'''
|
| 434 |
+
Getitem function for mp matrix class with slice index enabled
|
| 435 |
+
it allows the following assingments
|
| 436 |
+
scalar to a slice of the matrix
|
| 437 |
+
B = A[:,2:6]
|
| 438 |
+
'''
|
| 439 |
+
# Convert vector to matrix indexing
|
| 440 |
+
if isinstance(key, int) or isinstance(key,slice):
|
| 441 |
+
# only sufficent for vectors
|
| 442 |
+
if self.__rows == 1:
|
| 443 |
+
key = (0, key)
|
| 444 |
+
elif self.__cols == 1:
|
| 445 |
+
key = (key, 0)
|
| 446 |
+
else:
|
| 447 |
+
raise IndexError('insufficient indices for matrix')
|
| 448 |
+
|
| 449 |
+
if isinstance(key[0],slice) or isinstance(key[1],slice):
|
| 450 |
+
|
| 451 |
+
#Rows
|
| 452 |
+
if isinstance(key[0],slice):
|
| 453 |
+
#Check bounds
|
| 454 |
+
if (key[0].start is None or key[0].start >= 0) and \
|
| 455 |
+
(key[0].stop is None or key[0].stop <= self.__rows+1):
|
| 456 |
+
# Generate indices
|
| 457 |
+
rows = xrange(*key[0].indices(self.__rows))
|
| 458 |
+
else:
|
| 459 |
+
raise IndexError('Row index out of bounds')
|
| 460 |
+
else:
|
| 461 |
+
# Single row
|
| 462 |
+
rows = [key[0]]
|
| 463 |
+
|
| 464 |
+
# Columns
|
| 465 |
+
if isinstance(key[1],slice):
|
| 466 |
+
# Check bounds
|
| 467 |
+
if (key[1].start is None or key[1].start >= 0) and \
|
| 468 |
+
(key[1].stop is None or key[1].stop <= self.__cols+1):
|
| 469 |
+
# Generate indices
|
| 470 |
+
columns = xrange(*key[1].indices(self.__cols))
|
| 471 |
+
else:
|
| 472 |
+
raise IndexError('Column index out of bounds')
|
| 473 |
+
|
| 474 |
+
else:
|
| 475 |
+
# Single column
|
| 476 |
+
columns = [key[1]]
|
| 477 |
+
|
| 478 |
+
# Create matrix slice
|
| 479 |
+
m = self.ctx.matrix(len(rows),len(columns))
|
| 480 |
+
|
| 481 |
+
# Assign elements to the output matrix
|
| 482 |
+
for i,x in enumerate(rows):
|
| 483 |
+
for j,y in enumerate(columns):
|
| 484 |
+
m.__set_element((i,j),self.__get_element((x,y)))
|
| 485 |
+
|
| 486 |
+
return m
|
| 487 |
+
|
| 488 |
+
else:
|
| 489 |
+
# single element extraction
|
| 490 |
+
if key[0] >= self.__rows or key[1] >= self.__cols:
|
| 491 |
+
raise IndexError('matrix index out of range')
|
| 492 |
+
if key in self.__data:
|
| 493 |
+
return self.__data[key]
|
| 494 |
+
else:
|
| 495 |
+
return self.ctx.zero
|
| 496 |
+
|
| 497 |
+
def __setitem__(self, key, value):
|
| 498 |
+
# setitem function for mp matrix class with slice index enabled
|
| 499 |
+
# it allows the following assingments
|
| 500 |
+
# scalar to a slice of the matrix
|
| 501 |
+
# A[:,2:6] = 2.5
|
| 502 |
+
# submatrix to matrix (the value matrix should be the same size as the slice size)
|
| 503 |
+
# A[3,:] = B where A is n x m and B is n x 1
|
| 504 |
+
# Convert vector to matrix indexing
|
| 505 |
+
if isinstance(key, int) or isinstance(key,slice):
|
| 506 |
+
# only sufficent for vectors
|
| 507 |
+
if self.__rows == 1:
|
| 508 |
+
key = (0, key)
|
| 509 |
+
elif self.__cols == 1:
|
| 510 |
+
key = (key, 0)
|
| 511 |
+
else:
|
| 512 |
+
raise IndexError('insufficient indices for matrix')
|
| 513 |
+
# Slice indexing
|
| 514 |
+
if isinstance(key[0],slice) or isinstance(key[1],slice):
|
| 515 |
+
# Rows
|
| 516 |
+
if isinstance(key[0],slice):
|
| 517 |
+
# Check bounds
|
| 518 |
+
if (key[0].start is None or key[0].start >= 0) and \
|
| 519 |
+
(key[0].stop is None or key[0].stop <= self.__rows+1):
|
| 520 |
+
# generate row indices
|
| 521 |
+
rows = xrange(*key[0].indices(self.__rows))
|
| 522 |
+
else:
|
| 523 |
+
raise IndexError('Row index out of bounds')
|
| 524 |
+
else:
|
| 525 |
+
# Single row
|
| 526 |
+
rows = [key[0]]
|
| 527 |
+
# Columns
|
| 528 |
+
if isinstance(key[1],slice):
|
| 529 |
+
# Check bounds
|
| 530 |
+
if (key[1].start is None or key[1].start >= 0) and \
|
| 531 |
+
(key[1].stop is None or key[1].stop <= self.__cols+1):
|
| 532 |
+
# Generate column indices
|
| 533 |
+
columns = xrange(*key[1].indices(self.__cols))
|
| 534 |
+
else:
|
| 535 |
+
raise IndexError('Column index out of bounds')
|
| 536 |
+
else:
|
| 537 |
+
# Single column
|
| 538 |
+
columns = [key[1]]
|
| 539 |
+
# Assign slice with a scalar
|
| 540 |
+
if isinstance(value,self.ctx.matrix):
|
| 541 |
+
# Assign elements to matrix if input and output dimensions match
|
| 542 |
+
if len(rows) == value.rows and len(columns) == value.cols:
|
| 543 |
+
for i,x in enumerate(rows):
|
| 544 |
+
for j,y in enumerate(columns):
|
| 545 |
+
self.__set_element((x,y), value.__get_element((i,j)))
|
| 546 |
+
else:
|
| 547 |
+
raise ValueError('Dimensions do not match')
|
| 548 |
+
else:
|
| 549 |
+
# Assign slice with scalars
|
| 550 |
+
value = self.ctx.convert(value)
|
| 551 |
+
for i in rows:
|
| 552 |
+
for j in columns:
|
| 553 |
+
self.__set_element((i,j), value)
|
| 554 |
+
else:
|
| 555 |
+
# Single element assingment
|
| 556 |
+
# Check bounds
|
| 557 |
+
if key[0] >= self.__rows or key[1] >= self.__cols:
|
| 558 |
+
raise IndexError('matrix index out of range')
|
| 559 |
+
# Convert and store value
|
| 560 |
+
value = self.ctx.convert(value)
|
| 561 |
+
if value: # only store non-zeros
|
| 562 |
+
self.__data[key] = value
|
| 563 |
+
elif key in self.__data:
|
| 564 |
+
del self.__data[key]
|
| 565 |
+
|
| 566 |
+
if self._LU:
|
| 567 |
+
self._LU = None
|
| 568 |
+
return
|
| 569 |
+
|
| 570 |
+
def __iter__(self):
|
| 571 |
+
for i in xrange(self.__rows):
|
| 572 |
+
for j in xrange(self.__cols):
|
| 573 |
+
yield self[i,j]
|
| 574 |
+
|
| 575 |
+
def __mul__(self, other):
|
| 576 |
+
if isinstance(other, self.ctx.matrix):
|
| 577 |
+
# dot multiplication
|
| 578 |
+
if self.__cols != other.__rows:
|
| 579 |
+
raise ValueError('dimensions not compatible for multiplication')
|
| 580 |
+
new = self.ctx.matrix(self.__rows, other.__cols)
|
| 581 |
+
self_zero = self.ctx.zero
|
| 582 |
+
self_get = self.__data.get
|
| 583 |
+
other_zero = other.ctx.zero
|
| 584 |
+
other_get = other.__data.get
|
| 585 |
+
for i in xrange(self.__rows):
|
| 586 |
+
for j in xrange(other.__cols):
|
| 587 |
+
new[i, j] = self.ctx.fdot((self_get((i,k), self_zero), other_get((k,j), other_zero))
|
| 588 |
+
for k in xrange(other.__rows))
|
| 589 |
+
return new
|
| 590 |
+
else:
|
| 591 |
+
# try scalar multiplication
|
| 592 |
+
new = self.ctx.matrix(self.__rows, self.__cols)
|
| 593 |
+
for i in xrange(self.__rows):
|
| 594 |
+
for j in xrange(self.__cols):
|
| 595 |
+
new[i, j] = other * self[i, j]
|
| 596 |
+
return new
|
| 597 |
+
|
| 598 |
+
def __matmul__(self, other):
|
| 599 |
+
return self.__mul__(other)
|
| 600 |
+
|
| 601 |
+
def __rmul__(self, other):
|
| 602 |
+
# assume other is scalar and thus commutative
|
| 603 |
+
if isinstance(other, self.ctx.matrix):
|
| 604 |
+
raise TypeError("other should not be type of ctx.matrix")
|
| 605 |
+
return self.__mul__(other)
|
| 606 |
+
|
| 607 |
+
def __pow__(self, other):
|
| 608 |
+
# avoid cyclic import problems
|
| 609 |
+
#from linalg import inverse
|
| 610 |
+
if not isinstance(other, int):
|
| 611 |
+
raise ValueError('only integer exponents are supported')
|
| 612 |
+
if not self.__rows == self.__cols:
|
| 613 |
+
raise ValueError('only powers of square matrices are defined')
|
| 614 |
+
n = other
|
| 615 |
+
if n == 0:
|
| 616 |
+
return self.ctx.eye(self.__rows)
|
| 617 |
+
if n < 0:
|
| 618 |
+
n = -n
|
| 619 |
+
neg = True
|
| 620 |
+
else:
|
| 621 |
+
neg = False
|
| 622 |
+
i = n
|
| 623 |
+
y = 1
|
| 624 |
+
z = self.copy()
|
| 625 |
+
while i != 0:
|
| 626 |
+
if i % 2 == 1:
|
| 627 |
+
y = y * z
|
| 628 |
+
z = z*z
|
| 629 |
+
i = i // 2
|
| 630 |
+
if neg:
|
| 631 |
+
y = self.ctx.inverse(y)
|
| 632 |
+
return y
|
| 633 |
+
|
| 634 |
+
def __div__(self, other):
|
| 635 |
+
# assume other is scalar and do element-wise divison
|
| 636 |
+
assert not isinstance(other, self.ctx.matrix)
|
| 637 |
+
new = self.ctx.matrix(self.__rows, self.__cols)
|
| 638 |
+
for i in xrange(self.__rows):
|
| 639 |
+
for j in xrange(self.__cols):
|
| 640 |
+
new[i,j] = self[i,j] / other
|
| 641 |
+
return new
|
| 642 |
+
|
| 643 |
+
__truediv__ = __div__
|
| 644 |
+
|
| 645 |
+
def __add__(self, other):
|
| 646 |
+
if isinstance(other, self.ctx.matrix):
|
| 647 |
+
if not (self.__rows == other.__rows and self.__cols == other.__cols):
|
| 648 |
+
raise ValueError('incompatible dimensions for addition')
|
| 649 |
+
new = self.ctx.matrix(self.__rows, self.__cols)
|
| 650 |
+
for i in xrange(self.__rows):
|
| 651 |
+
for j in xrange(self.__cols):
|
| 652 |
+
new[i,j] = self[i,j] + other[i,j]
|
| 653 |
+
return new
|
| 654 |
+
else:
|
| 655 |
+
# assume other is scalar and add element-wise
|
| 656 |
+
new = self.ctx.matrix(self.__rows, self.__cols)
|
| 657 |
+
for i in xrange(self.__rows):
|
| 658 |
+
for j in xrange(self.__cols):
|
| 659 |
+
new[i,j] += self[i,j] + other
|
| 660 |
+
return new
|
| 661 |
+
|
| 662 |
+
def __radd__(self, other):
|
| 663 |
+
return self.__add__(other)
|
| 664 |
+
|
| 665 |
+
def __sub__(self, other):
|
| 666 |
+
if isinstance(other, self.ctx.matrix) and not (self.__rows == other.__rows
|
| 667 |
+
and self.__cols == other.__cols):
|
| 668 |
+
raise ValueError('incompatible dimensions for subtraction')
|
| 669 |
+
return self.__add__(other * (-1))
|
| 670 |
+
|
| 671 |
+
def __pos__(self):
|
| 672 |
+
"""
|
| 673 |
+
+M returns a copy of M, rounded to current working precision.
|
| 674 |
+
"""
|
| 675 |
+
return (+1) * self
|
| 676 |
+
|
| 677 |
+
def __neg__(self):
|
| 678 |
+
return (-1) * self
|
| 679 |
+
|
| 680 |
+
def __rsub__(self, other):
|
| 681 |
+
return -self + other
|
| 682 |
+
|
| 683 |
+
def __eq__(self, other):
|
| 684 |
+
return self.__rows == other.__rows and self.__cols == other.__cols \
|
| 685 |
+
and self.__data == other.__data
|
| 686 |
+
|
| 687 |
+
def __len__(self):
|
| 688 |
+
if self.rows == 1:
|
| 689 |
+
return self.cols
|
| 690 |
+
elif self.cols == 1:
|
| 691 |
+
return self.rows
|
| 692 |
+
else:
|
| 693 |
+
return self.rows # do it like numpy
|
| 694 |
+
|
| 695 |
+
def __getrows(self):
|
| 696 |
+
return self.__rows
|
| 697 |
+
|
| 698 |
+
def __setrows(self, value):
|
| 699 |
+
for key in self.__data.copy():
|
| 700 |
+
if key[0] >= value:
|
| 701 |
+
del self.__data[key]
|
| 702 |
+
self.__rows = value
|
| 703 |
+
|
| 704 |
+
rows = property(__getrows, __setrows, doc='number of rows')
|
| 705 |
+
|
| 706 |
+
def __getcols(self):
|
| 707 |
+
return self.__cols
|
| 708 |
+
|
| 709 |
+
def __setcols(self, value):
|
| 710 |
+
for key in self.__data.copy():
|
| 711 |
+
if key[1] >= value:
|
| 712 |
+
del self.__data[key]
|
| 713 |
+
self.__cols = value
|
| 714 |
+
|
| 715 |
+
cols = property(__getcols, __setcols, doc='number of columns')
|
| 716 |
+
|
| 717 |
+
def transpose(self):
|
| 718 |
+
new = self.ctx.matrix(self.__cols, self.__rows)
|
| 719 |
+
for i in xrange(self.__rows):
|
| 720 |
+
for j in xrange(self.__cols):
|
| 721 |
+
new[j,i] = self[i,j]
|
| 722 |
+
return new
|
| 723 |
+
|
| 724 |
+
T = property(transpose)
|
| 725 |
+
|
| 726 |
+
def conjugate(self):
|
| 727 |
+
return self.apply(self.ctx.conj)
|
| 728 |
+
|
| 729 |
+
def transpose_conj(self):
|
| 730 |
+
return self.conjugate().transpose()
|
| 731 |
+
|
| 732 |
+
H = property(transpose_conj)
|
| 733 |
+
|
| 734 |
+
def copy(self):
|
| 735 |
+
new = self.ctx.matrix(self.__rows, self.__cols)
|
| 736 |
+
new.__data = self.__data.copy()
|
| 737 |
+
return new
|
| 738 |
+
|
| 739 |
+
__copy__ = copy
|
| 740 |
+
|
| 741 |
+
def column(self, n):
|
| 742 |
+
m = self.ctx.matrix(self.rows, 1)
|
| 743 |
+
for i in range(self.rows):
|
| 744 |
+
m[i] = self[i,n]
|
| 745 |
+
return m
|
| 746 |
+
|
| 747 |
+
class MatrixMethods(object):
|
| 748 |
+
|
| 749 |
+
def __init__(ctx):
|
| 750 |
+
# XXX: subclass
|
| 751 |
+
ctx.matrix = type('matrix', (_matrix,), {})
|
| 752 |
+
ctx.matrix.ctx = ctx
|
| 753 |
+
ctx.matrix.convert = ctx.convert
|
| 754 |
+
|
| 755 |
+
def eye(ctx, n, **kwargs):
|
| 756 |
+
"""
|
| 757 |
+
Create square identity matrix n x n.
|
| 758 |
+
"""
|
| 759 |
+
A = ctx.matrix(n, **kwargs)
|
| 760 |
+
for i in xrange(n):
|
| 761 |
+
A[i,i] = 1
|
| 762 |
+
return A
|
| 763 |
+
|
| 764 |
+
def diag(ctx, diagonal, **kwargs):
|
| 765 |
+
"""
|
| 766 |
+
Create square diagonal matrix using given list.
|
| 767 |
+
|
| 768 |
+
Example:
|
| 769 |
+
>>> from mpmath import diag, mp
|
| 770 |
+
>>> mp.pretty = False
|
| 771 |
+
>>> diag([1, 2, 3])
|
| 772 |
+
matrix(
|
| 773 |
+
[['1.0', '0.0', '0.0'],
|
| 774 |
+
['0.0', '2.0', '0.0'],
|
| 775 |
+
['0.0', '0.0', '3.0']])
|
| 776 |
+
"""
|
| 777 |
+
A = ctx.matrix(len(diagonal), **kwargs)
|
| 778 |
+
for i in xrange(len(diagonal)):
|
| 779 |
+
A[i,i] = diagonal[i]
|
| 780 |
+
return A
|
| 781 |
+
|
| 782 |
+
def zeros(ctx, *args, **kwargs):
|
| 783 |
+
"""
|
| 784 |
+
Create matrix m x n filled with zeros.
|
| 785 |
+
One given dimension will create square matrix n x n.
|
| 786 |
+
|
| 787 |
+
Example:
|
| 788 |
+
>>> from mpmath import zeros, mp
|
| 789 |
+
>>> mp.pretty = False
|
| 790 |
+
>>> zeros(2)
|
| 791 |
+
matrix(
|
| 792 |
+
[['0.0', '0.0'],
|
| 793 |
+
['0.0', '0.0']])
|
| 794 |
+
"""
|
| 795 |
+
if len(args) == 1:
|
| 796 |
+
m = n = args[0]
|
| 797 |
+
elif len(args) == 2:
|
| 798 |
+
m = args[0]
|
| 799 |
+
n = args[1]
|
| 800 |
+
else:
|
| 801 |
+
raise TypeError('zeros expected at most 2 arguments, got %i' % len(args))
|
| 802 |
+
A = ctx.matrix(m, n, **kwargs)
|
| 803 |
+
for i in xrange(m):
|
| 804 |
+
for j in xrange(n):
|
| 805 |
+
A[i,j] = 0
|
| 806 |
+
return A
|
| 807 |
+
|
| 808 |
+
def ones(ctx, *args, **kwargs):
|
| 809 |
+
"""
|
| 810 |
+
Create matrix m x n filled with ones.
|
| 811 |
+
One given dimension will create square matrix n x n.
|
| 812 |
+
|
| 813 |
+
Example:
|
| 814 |
+
>>> from mpmath import ones, mp
|
| 815 |
+
>>> mp.pretty = False
|
| 816 |
+
>>> ones(2)
|
| 817 |
+
matrix(
|
| 818 |
+
[['1.0', '1.0'],
|
| 819 |
+
['1.0', '1.0']])
|
| 820 |
+
"""
|
| 821 |
+
if len(args) == 1:
|
| 822 |
+
m = n = args[0]
|
| 823 |
+
elif len(args) == 2:
|
| 824 |
+
m = args[0]
|
| 825 |
+
n = args[1]
|
| 826 |
+
else:
|
| 827 |
+
raise TypeError('ones expected at most 2 arguments, got %i' % len(args))
|
| 828 |
+
A = ctx.matrix(m, n, **kwargs)
|
| 829 |
+
for i in xrange(m):
|
| 830 |
+
for j in xrange(n):
|
| 831 |
+
A[i,j] = 1
|
| 832 |
+
return A
|
| 833 |
+
|
| 834 |
+
def hilbert(ctx, m, n=None):
|
| 835 |
+
"""
|
| 836 |
+
Create (pseudo) hilbert matrix m x n.
|
| 837 |
+
One given dimension will create hilbert matrix n x n.
|
| 838 |
+
|
| 839 |
+
The matrix is very ill-conditioned and symmetric, positive definite if
|
| 840 |
+
square.
|
| 841 |
+
"""
|
| 842 |
+
if n is None:
|
| 843 |
+
n = m
|
| 844 |
+
A = ctx.matrix(m, n)
|
| 845 |
+
for i in xrange(m):
|
| 846 |
+
for j in xrange(n):
|
| 847 |
+
A[i,j] = ctx.one / (i + j + 1)
|
| 848 |
+
return A
|
| 849 |
+
|
| 850 |
+
def randmatrix(ctx, m, n=None, min=0, max=1, **kwargs):
|
| 851 |
+
"""
|
| 852 |
+
Create a random m x n matrix.
|
| 853 |
+
|
| 854 |
+
All values are >= min and <max.
|
| 855 |
+
n defaults to m.
|
| 856 |
+
|
| 857 |
+
Example:
|
| 858 |
+
>>> from mpmath import randmatrix
|
| 859 |
+
>>> randmatrix(2) # doctest:+SKIP
|
| 860 |
+
matrix(
|
| 861 |
+
[['0.53491598236191806', '0.57195669543302752'],
|
| 862 |
+
['0.85589992269513615', '0.82444367501382143']])
|
| 863 |
+
"""
|
| 864 |
+
if not n:
|
| 865 |
+
n = m
|
| 866 |
+
A = ctx.matrix(m, n, **kwargs)
|
| 867 |
+
for i in xrange(m):
|
| 868 |
+
for j in xrange(n):
|
| 869 |
+
A[i,j] = ctx.rand() * (max - min) + min
|
| 870 |
+
return A
|
| 871 |
+
|
| 872 |
+
def swap_row(ctx, A, i, j):
|
| 873 |
+
"""
|
| 874 |
+
Swap row i with row j.
|
| 875 |
+
"""
|
| 876 |
+
if i == j:
|
| 877 |
+
return
|
| 878 |
+
if isinstance(A, ctx.matrix):
|
| 879 |
+
for k in xrange(A.cols):
|
| 880 |
+
A[i,k], A[j,k] = A[j,k], A[i,k]
|
| 881 |
+
elif isinstance(A, list):
|
| 882 |
+
A[i], A[j] = A[j], A[i]
|
| 883 |
+
else:
|
| 884 |
+
raise TypeError('could not interpret type')
|
| 885 |
+
|
| 886 |
+
def extend(ctx, A, b):
|
| 887 |
+
"""
|
| 888 |
+
Extend matrix A with column b and return result.
|
| 889 |
+
"""
|
| 890 |
+
if not isinstance(A, ctx.matrix):
|
| 891 |
+
raise TypeError("A should be a type of ctx.matrix")
|
| 892 |
+
if A.rows != len(b):
|
| 893 |
+
raise ValueError("Value should be equal to len(b)")
|
| 894 |
+
A = A.copy()
|
| 895 |
+
A.cols += 1
|
| 896 |
+
for i in xrange(A.rows):
|
| 897 |
+
A[i, A.cols-1] = b[i]
|
| 898 |
+
return A
|
| 899 |
+
|
| 900 |
+
def norm(ctx, x, p=2):
|
| 901 |
+
r"""
|
| 902 |
+
Gives the entrywise `p`-norm of an iterable *x*, i.e. the vector norm
|
| 903 |
+
`\left(\sum_k |x_k|^p\right)^{1/p}`, for any given `1 \le p \le \infty`.
|
| 904 |
+
|
| 905 |
+
Special cases:
|
| 906 |
+
|
| 907 |
+
If *x* is not iterable, this just returns ``absmax(x)``.
|
| 908 |
+
|
| 909 |
+
``p=1`` gives the sum of absolute values.
|
| 910 |
+
|
| 911 |
+
``p=2`` is the standard Euclidean vector norm.
|
| 912 |
+
|
| 913 |
+
``p=inf`` gives the magnitude of the largest element.
|
| 914 |
+
|
| 915 |
+
For *x* a matrix, ``p=2`` is the Frobenius norm.
|
| 916 |
+
For operator matrix norms, use :func:`~mpmath.mnorm` instead.
|
| 917 |
+
|
| 918 |
+
You can use the string 'inf' as well as float('inf') or mpf('inf')
|
| 919 |
+
to specify the infinity norm.
|
| 920 |
+
|
| 921 |
+
**Examples**
|
| 922 |
+
|
| 923 |
+
>>> from mpmath import *
|
| 924 |
+
>>> mp.dps = 15; mp.pretty = False
|
| 925 |
+
>>> x = matrix([-10, 2, 100])
|
| 926 |
+
>>> norm(x, 1)
|
| 927 |
+
mpf('112.0')
|
| 928 |
+
>>> norm(x, 2)
|
| 929 |
+
mpf('100.5186549850325')
|
| 930 |
+
>>> norm(x, inf)
|
| 931 |
+
mpf('100.0')
|
| 932 |
+
|
| 933 |
+
"""
|
| 934 |
+
try:
|
| 935 |
+
iter(x)
|
| 936 |
+
except TypeError:
|
| 937 |
+
return ctx.absmax(x)
|
| 938 |
+
if type(p) is not int:
|
| 939 |
+
p = ctx.convert(p)
|
| 940 |
+
if p == ctx.inf:
|
| 941 |
+
return max(ctx.absmax(i) for i in x)
|
| 942 |
+
elif p == 1:
|
| 943 |
+
return ctx.fsum(x, absolute=1)
|
| 944 |
+
elif p == 2:
|
| 945 |
+
return ctx.sqrt(ctx.fsum(x, absolute=1, squared=1))
|
| 946 |
+
elif p > 1:
|
| 947 |
+
return ctx.nthroot(ctx.fsum(abs(i)**p for i in x), p)
|
| 948 |
+
else:
|
| 949 |
+
raise ValueError('p has to be >= 1')
|
| 950 |
+
|
| 951 |
+
def mnorm(ctx, A, p=1):
|
| 952 |
+
r"""
|
| 953 |
+
Gives the matrix (operator) `p`-norm of A. Currently ``p=1`` and ``p=inf``
|
| 954 |
+
are supported:
|
| 955 |
+
|
| 956 |
+
``p=1`` gives the 1-norm (maximal column sum)
|
| 957 |
+
|
| 958 |
+
``p=inf`` gives the `\infty`-norm (maximal row sum).
|
| 959 |
+
You can use the string 'inf' as well as float('inf') or mpf('inf')
|
| 960 |
+
|
| 961 |
+
``p=2`` (not implemented) for a square matrix is the usual spectral
|
| 962 |
+
matrix norm, i.e. the largest singular value.
|
| 963 |
+
|
| 964 |
+
``p='f'`` (or 'F', 'fro', 'Frobenius, 'frobenius') gives the
|
| 965 |
+
Frobenius norm, which is the elementwise 2-norm. The Frobenius norm is an
|
| 966 |
+
approximation of the spectral norm and satisfies
|
| 967 |
+
|
| 968 |
+
.. math ::
|
| 969 |
+
|
| 970 |
+
\frac{1}{\sqrt{\mathrm{rank}(A)}} \|A\|_F \le \|A\|_2 \le \|A\|_F
|
| 971 |
+
|
| 972 |
+
The Frobenius norm lacks some mathematical properties that might
|
| 973 |
+
be expected of a norm.
|
| 974 |
+
|
| 975 |
+
For general elementwise `p`-norms, use :func:`~mpmath.norm` instead.
|
| 976 |
+
|
| 977 |
+
**Examples**
|
| 978 |
+
|
| 979 |
+
>>> from mpmath import *
|
| 980 |
+
>>> mp.dps = 15; mp.pretty = False
|
| 981 |
+
>>> A = matrix([[1, -1000], [100, 50]])
|
| 982 |
+
>>> mnorm(A, 1)
|
| 983 |
+
mpf('1050.0')
|
| 984 |
+
>>> mnorm(A, inf)
|
| 985 |
+
mpf('1001.0')
|
| 986 |
+
>>> mnorm(A, 'F')
|
| 987 |
+
mpf('1006.2310867787777')
|
| 988 |
+
|
| 989 |
+
"""
|
| 990 |
+
A = ctx.matrix(A)
|
| 991 |
+
if type(p) is not int:
|
| 992 |
+
if type(p) is str and 'frobenius'.startswith(p.lower()):
|
| 993 |
+
return ctx.norm(A, 2)
|
| 994 |
+
p = ctx.convert(p)
|
| 995 |
+
m, n = A.rows, A.cols
|
| 996 |
+
if p == 1:
|
| 997 |
+
return max(ctx.fsum((A[i,j] for i in xrange(m)), absolute=1) for j in xrange(n))
|
| 998 |
+
elif p == ctx.inf:
|
| 999 |
+
return max(ctx.fsum((A[i,j] for j in xrange(n)), absolute=1) for i in xrange(m))
|
| 1000 |
+
else:
|
| 1001 |
+
raise NotImplementedError("matrix p-norm for arbitrary p")
|
| 1002 |
+
|
| 1003 |
+
if __name__ == '__main__':
|
| 1004 |
+
import doctest
|
| 1005 |
+
doctest.testmod()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (214 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublasLt.h
ADDED
|
@@ -0,0 +1,1853 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2022 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
#pragma once
|
| 50 |
+
|
| 51 |
+
#ifndef CUBLASAPI
|
| 52 |
+
#ifdef __CUDACC__
|
| 53 |
+
#define CUBLASAPI __host__ __device__
|
| 54 |
+
#else
|
| 55 |
+
#define CUBLASAPI
|
| 56 |
+
#endif
|
| 57 |
+
#endif
|
| 58 |
+
|
| 59 |
+
#include <cublas_api.h>
|
| 60 |
+
|
| 61 |
+
#include <stdint.h>
|
| 62 |
+
#include <stddef.h>
|
| 63 |
+
#include <stdio.h>
|
| 64 |
+
|
| 65 |
+
#if defined(__cplusplus)
|
| 66 |
+
extern "C" {
|
| 67 |
+
#endif /* __cplusplus */
|
| 68 |
+
|
| 69 |
+
/** Opaque structure holding CUBLASLT context
|
| 70 |
+
*/
|
| 71 |
+
typedef struct cublasLtContext* cublasLtHandle_t;
|
| 72 |
+
|
| 73 |
+
cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t* lightHandle);
|
| 74 |
+
|
| 75 |
+
cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle);
|
| 76 |
+
|
| 77 |
+
const char* CUBLASWINAPI cublasLtGetStatusName(cublasStatus_t status);
|
| 78 |
+
|
| 79 |
+
const char* CUBLASWINAPI cublasLtGetStatusString(cublasStatus_t status);
|
| 80 |
+
|
| 81 |
+
size_t CUBLASWINAPI cublasLtGetVersion(void);
|
| 82 |
+
|
| 83 |
+
size_t CUBLASWINAPI cublasLtGetCudartVersion(void);
|
| 84 |
+
|
| 85 |
+
cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type, int* value);
|
| 86 |
+
|
| 87 |
+
cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheGetCapacity(size_t* capacity);
|
| 88 |
+
cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheSetCapacity(size_t capacity);
|
| 89 |
+
|
| 90 |
+
/** Semi-opaque descriptor for matrix memory layout
|
| 91 |
+
*/
|
| 92 |
+
typedef struct {
|
| 93 |
+
uint64_t data[8];
|
| 94 |
+
} cublasLtMatrixLayoutOpaque_t;
|
| 95 |
+
|
| 96 |
+
/** Opaque descriptor for matrix memory layout
|
| 97 |
+
*/
|
| 98 |
+
typedef cublasLtMatrixLayoutOpaque_t* cublasLtMatrixLayout_t;
|
| 99 |
+
|
| 100 |
+
/** Semi-opaque algorithm descriptor (to avoid complicated alloc/free schemes)
|
| 101 |
+
*
|
| 102 |
+
* This structure can be trivially serialized and later restored for use with the same version of cuBLAS library to save
|
| 103 |
+
* on selecting the right configuration again.
|
| 104 |
+
*/
|
| 105 |
+
typedef struct {
|
| 106 |
+
uint64_t data[8];
|
| 107 |
+
} cublasLtMatmulAlgo_t;
|
| 108 |
+
|
| 109 |
+
/** Semi-opaque descriptor for cublasLtMatmul() operation details
|
| 110 |
+
*/
|
| 111 |
+
typedef struct {
|
| 112 |
+
uint64_t data[23];
|
| 113 |
+
} cublasLtMatmulDescOpaque_t;
|
| 114 |
+
|
| 115 |
+
/** Opaque descriptor for cublasLtMatmul() operation details
|
| 116 |
+
*/
|
| 117 |
+
typedef cublasLtMatmulDescOpaque_t* cublasLtMatmulDesc_t;
|
| 118 |
+
|
| 119 |
+
/** Semi-opaque descriptor for cublasLtMatrixTransform() operation details
|
| 120 |
+
*/
|
| 121 |
+
typedef struct {
|
| 122 |
+
uint64_t data[8];
|
| 123 |
+
} cublasLtMatrixTransformDescOpaque_t;
|
| 124 |
+
|
| 125 |
+
/** Opaque descriptor for cublasLtMatrixTransform() operation details
|
| 126 |
+
*/
|
| 127 |
+
typedef cublasLtMatrixTransformDescOpaque_t* cublasLtMatrixTransformDesc_t;
|
| 128 |
+
|
| 129 |
+
/** Semi-opaque descriptor for cublasLtMatmulPreference() operation details
|
| 130 |
+
*/
|
| 131 |
+
typedef struct {
|
| 132 |
+
uint64_t data[10];
|
| 133 |
+
} cublasLtMatmulPreferenceOpaque_t;
|
| 134 |
+
|
| 135 |
+
/** Opaque descriptor for cublasLtMatmulAlgoGetHeuristic() configuration
|
| 136 |
+
*/
|
| 137 |
+
typedef cublasLtMatmulPreferenceOpaque_t* cublasLtMatmulPreference_t;
|
| 138 |
+
|
| 139 |
+
/** Tile size (in C/D matrix Rows x Cols)
|
| 140 |
+
*
|
| 141 |
+
* General order of tile IDs is sorted by size first and by first dimension second.
|
| 142 |
+
*/
|
| 143 |
+
typedef enum {
|
| 144 |
+
CUBLASLT_MATMUL_TILE_UNDEFINED = 0,
|
| 145 |
+
CUBLASLT_MATMUL_TILE_8x8 = 1,
|
| 146 |
+
CUBLASLT_MATMUL_TILE_8x16 = 2,
|
| 147 |
+
CUBLASLT_MATMUL_TILE_16x8 = 3,
|
| 148 |
+
CUBLASLT_MATMUL_TILE_8x32 = 4,
|
| 149 |
+
CUBLASLT_MATMUL_TILE_16x16 = 5,
|
| 150 |
+
CUBLASLT_MATMUL_TILE_32x8 = 6,
|
| 151 |
+
CUBLASLT_MATMUL_TILE_8x64 = 7,
|
| 152 |
+
CUBLASLT_MATMUL_TILE_16x32 = 8,
|
| 153 |
+
CUBLASLT_MATMUL_TILE_32x16 = 9,
|
| 154 |
+
CUBLASLT_MATMUL_TILE_64x8 = 10,
|
| 155 |
+
CUBLASLT_MATMUL_TILE_32x32 = 11,
|
| 156 |
+
CUBLASLT_MATMUL_TILE_32x64 = 12,
|
| 157 |
+
CUBLASLT_MATMUL_TILE_64x32 = 13,
|
| 158 |
+
CUBLASLT_MATMUL_TILE_32x128 = 14,
|
| 159 |
+
CUBLASLT_MATMUL_TILE_64x64 = 15,
|
| 160 |
+
CUBLASLT_MATMUL_TILE_128x32 = 16,
|
| 161 |
+
CUBLASLT_MATMUL_TILE_64x128 = 17,
|
| 162 |
+
CUBLASLT_MATMUL_TILE_128x64 = 18,
|
| 163 |
+
CUBLASLT_MATMUL_TILE_64x256 = 19,
|
| 164 |
+
CUBLASLT_MATMUL_TILE_128x128 = 20,
|
| 165 |
+
CUBLASLT_MATMUL_TILE_256x64 = 21,
|
| 166 |
+
CUBLASLT_MATMUL_TILE_64x512 = 22,
|
| 167 |
+
CUBLASLT_MATMUL_TILE_128x256 = 23,
|
| 168 |
+
CUBLASLT_MATMUL_TILE_256x128 = 24,
|
| 169 |
+
CUBLASLT_MATMUL_TILE_512x64 = 25,
|
| 170 |
+
CUBLASLT_MATMUL_TILE_64x96 = 26,
|
| 171 |
+
CUBLASLT_MATMUL_TILE_96x64 = 27,
|
| 172 |
+
CUBLASLT_MATMUL_TILE_96x128 = 28,
|
| 173 |
+
CUBLASLT_MATMUL_TILE_128x160 = 29,
|
| 174 |
+
CUBLASLT_MATMUL_TILE_160x128 = 30,
|
| 175 |
+
CUBLASLT_MATMUL_TILE_192x128 = 31,
|
| 176 |
+
CUBLASLT_MATMUL_TILE_128x192 = 32,
|
| 177 |
+
CUBLASLT_MATMUL_TILE_128x96 = 33,
|
| 178 |
+
CUBLASLT_MATMUL_TILE_END
|
| 179 |
+
} cublasLtMatmulTile_t;
|
| 180 |
+
|
| 181 |
+
/** Size and number of stages in which elements are read into shared memory
|
| 182 |
+
*
|
| 183 |
+
* General order of stages IDs is sorted by stage size first and by number of stages second.
|
| 184 |
+
*/
|
| 185 |
+
typedef enum {
|
| 186 |
+
CUBLASLT_MATMUL_STAGES_UNDEFINED = 0,
|
| 187 |
+
CUBLASLT_MATMUL_STAGES_16x1 = 1,
|
| 188 |
+
CUBLASLT_MATMUL_STAGES_16x2 = 2,
|
| 189 |
+
CUBLASLT_MATMUL_STAGES_16x3 = 3,
|
| 190 |
+
CUBLASLT_MATMUL_STAGES_16x4 = 4,
|
| 191 |
+
CUBLASLT_MATMUL_STAGES_16x5 = 5,
|
| 192 |
+
CUBLASLT_MATMUL_STAGES_16x6 = 6,
|
| 193 |
+
CUBLASLT_MATMUL_STAGES_32x1 = 7,
|
| 194 |
+
CUBLASLT_MATMUL_STAGES_32x2 = 8,
|
| 195 |
+
CUBLASLT_MATMUL_STAGES_32x3 = 9,
|
| 196 |
+
CUBLASLT_MATMUL_STAGES_32x4 = 10,
|
| 197 |
+
CUBLASLT_MATMUL_STAGES_32x5 = 11,
|
| 198 |
+
CUBLASLT_MATMUL_STAGES_32x6 = 12,
|
| 199 |
+
CUBLASLT_MATMUL_STAGES_64x1 = 13,
|
| 200 |
+
CUBLASLT_MATMUL_STAGES_64x2 = 14,
|
| 201 |
+
CUBLASLT_MATMUL_STAGES_64x3 = 15,
|
| 202 |
+
CUBLASLT_MATMUL_STAGES_64x4 = 16,
|
| 203 |
+
CUBLASLT_MATMUL_STAGES_64x5 = 17,
|
| 204 |
+
CUBLASLT_MATMUL_STAGES_64x6 = 18,
|
| 205 |
+
CUBLASLT_MATMUL_STAGES_128x1 = 19,
|
| 206 |
+
CUBLASLT_MATMUL_STAGES_128x2 = 20,
|
| 207 |
+
CUBLASLT_MATMUL_STAGES_128x3 = 21,
|
| 208 |
+
CUBLASLT_MATMUL_STAGES_128x4 = 22,
|
| 209 |
+
CUBLASLT_MATMUL_STAGES_128x5 = 23,
|
| 210 |
+
CUBLASLT_MATMUL_STAGES_128x6 = 24,
|
| 211 |
+
CUBLASLT_MATMUL_STAGES_32x10 = 25,
|
| 212 |
+
CUBLASLT_MATMUL_STAGES_8x4 = 26,
|
| 213 |
+
CUBLASLT_MATMUL_STAGES_16x10 = 27,
|
| 214 |
+
CUBLASLT_MATMUL_STAGES_8x5 = 28,
|
| 215 |
+
CUBLASLT_MATMUL_STAGES_16x80 = 29,
|
| 216 |
+
CUBLASLT_MATMUL_STAGES_64x80 = 30,
|
| 217 |
+
CUBLASLT_MATMUL_STAGES_8x3 = 31,
|
| 218 |
+
CUBLASLT_MATMUL_STAGES_8xAUTO = 32,
|
| 219 |
+
CUBLASLT_MATMUL_STAGES_16xAUTO = 33,
|
| 220 |
+
CUBLASLT_MATMUL_STAGES_32xAUTO = 34,
|
| 221 |
+
CUBLASLT_MATMUL_STAGES_64xAUTO = 35,
|
| 222 |
+
CUBLASLT_MATMUL_STAGES_128xAUTO = 36,
|
| 223 |
+
CUBLASLT_MATMUL_STAGES_END
|
| 224 |
+
} cublasLtMatmulStages_t;
|
| 225 |
+
|
| 226 |
+
/** Thread Block Cluster size
|
| 227 |
+
*
|
| 228 |
+
* Typically dimensioned similar to cublasLtMatmulTile_t, with the third coordinate unused at this time.
|
| 229 |
+
*/
|
| 230 |
+
typedef enum {
|
| 231 |
+
/** Let library pick cluster shape automatically */
|
| 232 |
+
CUBLASLT_CLUSTER_SHAPE_AUTO = 0,
|
| 233 |
+
CUBLASLT_CLUSTER_SHAPE_1x1x1 = 2,
|
| 234 |
+
CUBLASLT_CLUSTER_SHAPE_2x1x1 = 3,
|
| 235 |
+
CUBLASLT_CLUSTER_SHAPE_4x1x1 = 4,
|
| 236 |
+
CUBLASLT_CLUSTER_SHAPE_1x2x1 = 5,
|
| 237 |
+
CUBLASLT_CLUSTER_SHAPE_2x2x1 = 6,
|
| 238 |
+
CUBLASLT_CLUSTER_SHAPE_4x2x1 = 7,
|
| 239 |
+
CUBLASLT_CLUSTER_SHAPE_1x4x1 = 8,
|
| 240 |
+
CUBLASLT_CLUSTER_SHAPE_2x4x1 = 9,
|
| 241 |
+
CUBLASLT_CLUSTER_SHAPE_4x4x1 = 10,
|
| 242 |
+
CUBLASLT_CLUSTER_SHAPE_8x1x1 = 11,
|
| 243 |
+
CUBLASLT_CLUSTER_SHAPE_1x8x1 = 12,
|
| 244 |
+
CUBLASLT_CLUSTER_SHAPE_8x2x1 = 13,
|
| 245 |
+
CUBLASLT_CLUSTER_SHAPE_2x8x1 = 14,
|
| 246 |
+
CUBLASLT_CLUSTER_SHAPE_16x1x1 = 15,
|
| 247 |
+
CUBLASLT_CLUSTER_SHAPE_1x16x1 = 16,
|
| 248 |
+
CUBLASLT_CLUSTER_SHAPE_3x1x1 = 17,
|
| 249 |
+
CUBLASLT_CLUSTER_SHAPE_5x1x1 = 18,
|
| 250 |
+
CUBLASLT_CLUSTER_SHAPE_6x1x1 = 19,
|
| 251 |
+
CUBLASLT_CLUSTER_SHAPE_7x1x1 = 20,
|
| 252 |
+
CUBLASLT_CLUSTER_SHAPE_9x1x1 = 21,
|
| 253 |
+
CUBLASLT_CLUSTER_SHAPE_10x1x1 = 22,
|
| 254 |
+
CUBLASLT_CLUSTER_SHAPE_11x1x1 = 23,
|
| 255 |
+
CUBLASLT_CLUSTER_SHAPE_12x1x1 = 24,
|
| 256 |
+
CUBLASLT_CLUSTER_SHAPE_13x1x1 = 25,
|
| 257 |
+
CUBLASLT_CLUSTER_SHAPE_14x1x1 = 26,
|
| 258 |
+
CUBLASLT_CLUSTER_SHAPE_15x1x1 = 27,
|
| 259 |
+
CUBLASLT_CLUSTER_SHAPE_3x2x1 = 28,
|
| 260 |
+
CUBLASLT_CLUSTER_SHAPE_5x2x1 = 29,
|
| 261 |
+
CUBLASLT_CLUSTER_SHAPE_6x2x1 = 30,
|
| 262 |
+
CUBLASLT_CLUSTER_SHAPE_7x2x1 = 31,
|
| 263 |
+
CUBLASLT_CLUSTER_SHAPE_1x3x1 = 32,
|
| 264 |
+
CUBLASLT_CLUSTER_SHAPE_2x3x1 = 33,
|
| 265 |
+
CUBLASLT_CLUSTER_SHAPE_3x3x1 = 34,
|
| 266 |
+
CUBLASLT_CLUSTER_SHAPE_4x3x1 = 35,
|
| 267 |
+
CUBLASLT_CLUSTER_SHAPE_5x3x1 = 36,
|
| 268 |
+
CUBLASLT_CLUSTER_SHAPE_3x4x1 = 37,
|
| 269 |
+
CUBLASLT_CLUSTER_SHAPE_1x5x1 = 38,
|
| 270 |
+
CUBLASLT_CLUSTER_SHAPE_2x5x1 = 39,
|
| 271 |
+
CUBLASLT_CLUSTER_SHAPE_3x5x1 = 40,
|
| 272 |
+
CUBLASLT_CLUSTER_SHAPE_1x6x1 = 41,
|
| 273 |
+
CUBLASLT_CLUSTER_SHAPE_2x6x1 = 42,
|
| 274 |
+
CUBLASLT_CLUSTER_SHAPE_1x7x1 = 43,
|
| 275 |
+
CUBLASLT_CLUSTER_SHAPE_2x7x1 = 44,
|
| 276 |
+
CUBLASLT_CLUSTER_SHAPE_1x9x1 = 45,
|
| 277 |
+
CUBLASLT_CLUSTER_SHAPE_1x10x1 = 46,
|
| 278 |
+
CUBLASLT_CLUSTER_SHAPE_1x11x1 = 47,
|
| 279 |
+
CUBLASLT_CLUSTER_SHAPE_1x12x1 = 48,
|
| 280 |
+
CUBLASLT_CLUSTER_SHAPE_1x13x1 = 49,
|
| 281 |
+
CUBLASLT_CLUSTER_SHAPE_1x14x1 = 50,
|
| 282 |
+
CUBLASLT_CLUSTER_SHAPE_1x15x1 = 51,
|
| 283 |
+
CUBLASLT_CLUSTER_SHAPE_END
|
| 284 |
+
} cublasLtClusterShape_t;
|
| 285 |
+
|
| 286 |
+
/** Inner size of the kernel
|
| 287 |
+
*
|
| 288 |
+
* Represents various aspects of internal kernel design, that don't impact CUDA grid size but may have other more subtle
|
| 289 |
+
* effects.
|
| 290 |
+
*
|
| 291 |
+
*/
|
| 292 |
+
typedef enum {
|
| 293 |
+
CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED = 0,
|
| 294 |
+
CUBLASLT_MATMUL_INNER_SHAPE_MMA884 = 1,
|
| 295 |
+
CUBLASLT_MATMUL_INNER_SHAPE_MMA1684 = 2,
|
| 296 |
+
CUBLASLT_MATMUL_INNER_SHAPE_MMA1688 = 3,
|
| 297 |
+
CUBLASLT_MATMUL_INNER_SHAPE_MMA16816 = 4,
|
| 298 |
+
CUBLASLT_MATMUL_INNER_SHAPE_END
|
| 299 |
+
} cublasLtMatmulInnerShape_t;
|
| 300 |
+
|
| 301 |
+
/** Pointer mode to use for alpha/beta */
|
| 302 |
+
typedef enum {
|
| 303 |
+
/** matches CUBLAS_POINTER_MODE_HOST, pointer targets a single value host memory */
|
| 304 |
+
CUBLASLT_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST,
|
| 305 |
+
/** matches CUBLAS_POINTER_MODE_DEVICE, pointer targets a single value device memory */
|
| 306 |
+
CUBLASLT_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE,
|
| 307 |
+
/** pointer targets an array in device memory */
|
| 308 |
+
CUBLASLT_POINTER_MODE_DEVICE_VECTOR = 2,
|
| 309 |
+
/** alpha pointer targets an array in device memory, beta is zero. Note:
|
| 310 |
+
CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is not supported, must be 0. */
|
| 311 |
+
CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO = 3,
|
| 312 |
+
/** alpha pointer targets an array in device memory, beta is a single value in host memory. */
|
| 313 |
+
CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST = 4,
|
| 314 |
+
} cublasLtPointerMode_t;
|
| 315 |
+
|
| 316 |
+
/** Mask to define and query pointer mode capability */
|
| 317 |
+
typedef enum {
|
| 318 |
+
/** no initial filtering is performed when querying pointer mode capabilities, will use gemm pointer mode defined in
|
| 319 |
+
operation description **/
|
| 320 |
+
CUBLASLT_POINTER_MODE_MASK_NO_FILTERING = 0,
|
| 321 |
+
/** see CUBLASLT_POINTER_MODE_HOST */
|
| 322 |
+
CUBLASLT_POINTER_MODE_MASK_HOST = 1,
|
| 323 |
+
/** see CUBLASLT_POINTER_MODE_DEVICE */
|
| 324 |
+
CUBLASLT_POINTER_MODE_MASK_DEVICE = 2,
|
| 325 |
+
/** see CUBLASLT_POINTER_MODE_DEVICE_VECTOR */
|
| 326 |
+
CUBLASLT_POINTER_MODE_MASK_DEVICE_VECTOR = 4,
|
| 327 |
+
/** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO */
|
| 328 |
+
CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_ZERO = 8,
|
| 329 |
+
/** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST */
|
| 330 |
+
CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_HOST = 16,
|
| 331 |
+
} cublasLtPointerModeMask_t;
|
| 332 |
+
|
| 333 |
+
/** Implementation details that may affect numerical behavior of algorithms. */
|
| 334 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA (0x01ull << 0)
|
| 335 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA (0x02ull << 0)
|
| 336 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_IMMA (0x04ull << 0)
|
| 337 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_DMMA (0x08ull << 0)
|
| 338 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK (0xfeull << 0)
|
| 339 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK (0xffull << 0)
|
| 340 |
+
|
| 341 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_16F (0x01ull << 8)
|
| 342 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32F (0x02ull << 8)
|
| 343 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_64F (0x04ull << 8)
|
| 344 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32I (0x08ull << 8)
|
| 345 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_TYPE_MASK (0xffull << 8)
|
| 346 |
+
|
| 347 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16F (0x01ull << 16)
|
| 348 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16BF (0x02ull << 16)
|
| 349 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_TF32 (0x04ull << 16)
|
| 350 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_32F (0x08ull << 16)
|
| 351 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_64F (0x10ull << 16)
|
| 352 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8I (0x20ull << 16)
|
| 353 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E4M3 (0x40ull << 16)
|
| 354 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E5M2 (0x80ull << 16)
|
| 355 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK (0xffull << 16)
|
| 356 |
+
|
| 357 |
+
#define CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN (0x01ull << 32)
|
| 358 |
+
typedef uint64_t cublasLtNumericalImplFlags_t;
|
| 359 |
+
|
| 360 |
+
/** Execute matrix multiplication (D = alpha * op(A) * op(B) + beta * C).
|
| 361 |
+
*
|
| 362 |
+
* \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized
|
| 363 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g.
|
| 364 |
+
* when workspaceSizeInBytes is less than workspace required by configured
|
| 365 |
+
* algo
|
| 366 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured
|
| 367 |
+
* operation
|
| 368 |
+
* \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device
|
| 369 |
+
* \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device
|
| 370 |
+
* \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully
|
| 371 |
+
*/
|
| 372 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmul(cublasLtHandle_t lightHandle,
|
| 373 |
+
cublasLtMatmulDesc_t computeDesc,
|
| 374 |
+
const void* alpha, /* host or device pointer */
|
| 375 |
+
const void* A,
|
| 376 |
+
cublasLtMatrixLayout_t Adesc,
|
| 377 |
+
const void* B,
|
| 378 |
+
cublasLtMatrixLayout_t Bdesc,
|
| 379 |
+
const void* beta, /* host or device pointer */
|
| 380 |
+
const void* C,
|
| 381 |
+
cublasLtMatrixLayout_t Cdesc,
|
| 382 |
+
void* D,
|
| 383 |
+
cublasLtMatrixLayout_t Ddesc,
|
| 384 |
+
const cublasLtMatmulAlgo_t* algo,
|
| 385 |
+
void* workspace,
|
| 386 |
+
size_t workspaceSizeInBytes,
|
| 387 |
+
cudaStream_t stream);
|
| 388 |
+
|
| 389 |
+
/** Matrix layout conversion helper (C = alpha * op(A) + beta * op(B))
|
| 390 |
+
*
|
| 391 |
+
* Can be used to change memory order of data or to scale and shift the values.
|
| 392 |
+
*
|
| 393 |
+
* \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized
|
| 394 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g.
|
| 395 |
+
* when A is not NULL, but Adesc is NULL
|
| 396 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured
|
| 397 |
+
* operation
|
| 398 |
+
* \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device
|
| 399 |
+
* \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device
|
| 400 |
+
* \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully
|
| 401 |
+
*/
|
| 402 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(cublasLtHandle_t lightHandle,
|
| 403 |
+
cublasLtMatrixTransformDesc_t transformDesc,
|
| 404 |
+
const void* alpha, /* host or device pointer */
|
| 405 |
+
const void* A,
|
| 406 |
+
cublasLtMatrixLayout_t Adesc,
|
| 407 |
+
const void* beta, /* host or device pointer */
|
| 408 |
+
const void* B,
|
| 409 |
+
cublasLtMatrixLayout_t Bdesc,
|
| 410 |
+
void* C,
|
| 411 |
+
cublasLtMatrixLayout_t Cdesc,
|
| 412 |
+
cudaStream_t stream);
|
| 413 |
+
|
| 414 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 415 |
+
/* Helper functions for cublasLtMatrixLayout_t */
|
| 416 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 417 |
+
|
| 418 |
+
/** Enum for data ordering */
|
| 419 |
+
typedef enum {
|
| 420 |
+
/** Column-major
|
| 421 |
+
*
|
| 422 |
+
* Leading dimension is the stride (in elements) to the beginning of next column in memory.
|
| 423 |
+
*/
|
| 424 |
+
CUBLASLT_ORDER_COL = 0,
|
| 425 |
+
/** Row major
|
| 426 |
+
*
|
| 427 |
+
* Leading dimension is the stride (in elements) to the beginning of next row in memory.
|
| 428 |
+
*/
|
| 429 |
+
CUBLASLT_ORDER_ROW = 1,
|
| 430 |
+
/** Column-major ordered tiles of 32 columns.
|
| 431 |
+
*
|
| 432 |
+
* Leading dimension is the stride (in elements) to the beginning of next group of 32-columns. E.g. if matrix has 33
|
| 433 |
+
* columns and 2 rows, ld must be at least (32) * 2 = 64.
|
| 434 |
+
*/
|
| 435 |
+
CUBLASLT_ORDER_COL32 = 2,
|
| 436 |
+
/** Column-major ordered tiles of composite tiles with total 32 columns and 8 rows, tile composed of interleaved
|
| 437 |
+
* inner tiles of 4 columns within 4 even or odd rows in an alternating pattern.
|
| 438 |
+
*
|
| 439 |
+
* Leading dimension is the stride (in elements) to the beginning of the first 32 column x 8 row tile for the next
|
| 440 |
+
* 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32 * 8) * 1 = 256.
|
| 441 |
+
*/
|
| 442 |
+
CUBLASLT_ORDER_COL4_4R2_8C = 3,
|
| 443 |
+
/** Column-major ordered tiles of composite tiles with total 32 columns ands 32 rows.
|
| 444 |
+
* Element offset within the tile is calculated as (((row%8)/2*4+row/8)*2+row%2)*32+col.
|
| 445 |
+
*
|
| 446 |
+
* Leading dimension is the stride (in elements) to the beginning of the first 32 column x 32 row tile for the next
|
| 447 |
+
* 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32*32)*1 = 1024.
|
| 448 |
+
*/
|
| 449 |
+
CUBLASLT_ORDER_COL32_2R_4R4 = 4,
|
| 450 |
+
|
| 451 |
+
} cublasLtOrder_t;
|
| 452 |
+
|
| 453 |
+
/** Attributes of memory layout */
|
| 454 |
+
typedef enum {
|
| 455 |
+
/** Data type, see cudaDataType.
|
| 456 |
+
*
|
| 457 |
+
* uint32_t
|
| 458 |
+
*/
|
| 459 |
+
CUBLASLT_MATRIX_LAYOUT_TYPE = 0,
|
| 460 |
+
|
| 461 |
+
/** Memory order of the data, see cublasLtOrder_t.
|
| 462 |
+
*
|
| 463 |
+
* int32_t, default: CUBLASLT_ORDER_COL
|
| 464 |
+
*/
|
| 465 |
+
CUBLASLT_MATRIX_LAYOUT_ORDER = 1,
|
| 466 |
+
|
| 467 |
+
/** Number of rows.
|
| 468 |
+
*
|
| 469 |
+
* Usually only values that can be expressed as int32_t are supported.
|
| 470 |
+
*
|
| 471 |
+
* uint64_t
|
| 472 |
+
*/
|
| 473 |
+
CUBLASLT_MATRIX_LAYOUT_ROWS = 2,
|
| 474 |
+
|
| 475 |
+
/** Number of columns.
|
| 476 |
+
*
|
| 477 |
+
* Usually only values that can be expressed as int32_t are supported.
|
| 478 |
+
*
|
| 479 |
+
* uint64_t
|
| 480 |
+
*/
|
| 481 |
+
CUBLASLT_MATRIX_LAYOUT_COLS = 3,
|
| 482 |
+
|
| 483 |
+
/** Matrix leading dimension.
|
| 484 |
+
*
|
| 485 |
+
* For CUBLASLT_ORDER_COL this is stride (in elements) of matrix column, for more details and documentation for
|
| 486 |
+
* other memory orders see documentation for cublasLtOrder_t values.
|
| 487 |
+
*
|
| 488 |
+
* Currently only non-negative values are supported, must be large enough so that matrix memory locations are not
|
| 489 |
+
* overlapping (e.g. greater or equal to CUBLASLT_MATRIX_LAYOUT_ROWS in case of CUBLASLT_ORDER_COL).
|
| 490 |
+
*
|
| 491 |
+
* int64_t;
|
| 492 |
+
*/
|
| 493 |
+
CUBLASLT_MATRIX_LAYOUT_LD = 4,
|
| 494 |
+
|
| 495 |
+
/** Number of matmul operations to perform in the batch.
|
| 496 |
+
*
|
| 497 |
+
* See also CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT
|
| 498 |
+
*
|
| 499 |
+
* int32_t, default: 1
|
| 500 |
+
*/
|
| 501 |
+
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5,
|
| 502 |
+
|
| 503 |
+
/** Stride (in elements) to the next matrix for strided batch operation.
|
| 504 |
+
*
|
| 505 |
+
* When matrix type is planar-complex (CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET != 0), batch stride
|
| 506 |
+
* is interpreted by cublasLtMatmul() in number of real valued sub-elements. E.g. for data of type CUDA_C_16F,
|
| 507 |
+
* offset of 1024B is encoded as a stride of value 512 (since each element of the real and imaginary matrices
|
| 508 |
+
* is a 2B (16bit) floating point type).
|
| 509 |
+
*
|
| 510 |
+
* NOTE: A bug in cublasLtMatrixTransform() causes it to interpret the batch stride for a planar-complex matrix
|
| 511 |
+
* as if it was specified in number of complex elements. Therefore an offset of 1024B must be encoded as stride
|
| 512 |
+
* value 256 when calling cublasLtMatrixTransform() (each complex element is 4B with real and imaginary values 2B
|
| 513 |
+
* each). This behavior is expected to be corrected in the next major cuBLAS version.
|
| 514 |
+
*
|
| 515 |
+
* int64_t, default: 0
|
| 516 |
+
*/
|
| 517 |
+
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6,
|
| 518 |
+
|
| 519 |
+
/** Stride (in bytes) to the imaginary plane for planar complex layout.
|
| 520 |
+
*
|
| 521 |
+
* int64_t, default: 0 - 0 means that layout is regular (real and imaginary parts of complex numbers are interleaved
|
| 522 |
+
* in memory in each element)
|
| 523 |
+
*/
|
| 524 |
+
CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET = 7,
|
| 525 |
+
} cublasLtMatrixLayoutAttribute_t;
|
| 526 |
+
|
| 527 |
+
/** Internal. Do not use directly.
|
| 528 |
+
*/
|
| 529 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal( //
|
| 530 |
+
cublasLtMatrixLayout_t matLayout,
|
| 531 |
+
size_t size,
|
| 532 |
+
cudaDataType type,
|
| 533 |
+
uint64_t rows,
|
| 534 |
+
uint64_t cols,
|
| 535 |
+
int64_t ld);
|
| 536 |
+
|
| 537 |
+
/** Initialize matrix layout descriptor in pre-allocated space.
|
| 538 |
+
*
|
| 539 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
|
| 540 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 541 |
+
*/
|
| 542 |
+
static inline cublasStatus_t cublasLtMatrixLayoutInit(
|
| 543 |
+
cublasLtMatrixLayout_t matLayout, cudaDataType type, uint64_t rows, uint64_t cols, int64_t ld) {
|
| 544 |
+
return cublasLtMatrixLayoutInit_internal(matLayout, sizeof(*matLayout), type, rows, cols, ld);
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
/** Create new matrix layout descriptor.
|
| 548 |
+
*
|
| 549 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
|
| 550 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 551 |
+
*/
|
| 552 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate( //
|
| 553 |
+
cublasLtMatrixLayout_t* matLayout,
|
| 554 |
+
cudaDataType type,
|
| 555 |
+
uint64_t rows,
|
| 556 |
+
uint64_t cols,
|
| 557 |
+
int64_t ld);
|
| 558 |
+
|
| 559 |
+
/** Destroy matrix layout descriptor.
|
| 560 |
+
*
|
| 561 |
+
* \retval CUBLAS_STATUS_SUCCESS if operation was successful
|
| 562 |
+
*/
|
| 563 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout);
|
| 564 |
+
|
| 565 |
+
/** Set matrix layout descriptor attribute.
|
| 566 |
+
*
|
| 567 |
+
* \param[in] matLayout The descriptor
|
| 568 |
+
* \param[in] attr The attribute
|
| 569 |
+
* \param[in] buf memory address containing the new value
|
| 570 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 571 |
+
*
|
| 572 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 573 |
+
* selected attribute
|
| 574 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 575 |
+
*/
|
| 576 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute( //
|
| 577 |
+
cublasLtMatrixLayout_t matLayout,
|
| 578 |
+
cublasLtMatrixLayoutAttribute_t attr,
|
| 579 |
+
const void* buf,
|
| 580 |
+
size_t sizeInBytes);
|
| 581 |
+
|
| 582 |
+
/** Get matrix layout descriptor attribute.
|
| 583 |
+
*
|
| 584 |
+
* \param[in] matLayout The descriptor
|
| 585 |
+
* \param[in] attr The attribute
|
| 586 |
+
* \param[out] buf memory address containing the new value
|
| 587 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 588 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 589 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 590 |
+
*
|
| 591 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 592 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 593 |
+
* selected attribute
|
| 594 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 595 |
+
*/
|
| 596 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute( //
|
| 597 |
+
cublasLtMatrixLayout_t matLayout,
|
| 598 |
+
cublasLtMatrixLayoutAttribute_t attr,
|
| 599 |
+
void* buf,
|
| 600 |
+
size_t sizeInBytes,
|
| 601 |
+
size_t* sizeWritten);
|
| 602 |
+
|
| 603 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 604 |
+
/* Helper functions for cublasLtMatmulDesc_t */
|
| 605 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 606 |
+
|
| 607 |
+
/** Matmul descriptor attributes to define details of the operation. */
|
| 608 |
+
typedef enum {
|
| 609 |
+
/** Compute type, see cudaDataType. Defines data type used for multiply and accumulate operations and the
|
| 610 |
+
* accumulator during matrix multiplication.
|
| 611 |
+
*
|
| 612 |
+
* int32_t
|
| 613 |
+
*/
|
| 614 |
+
CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0,
|
| 615 |
+
|
| 616 |
+
/** Scale type, see cudaDataType. Defines data type of alpha and beta. Accumulator and value from matrix C are
|
| 617 |
+
* typically converted to scale type before final scaling. Value is then converted from scale type to type of matrix
|
| 618 |
+
* D before being stored in memory.
|
| 619 |
+
*
|
| 620 |
+
* int32_t, default: same as CUBLASLT_MATMUL_DESC_COMPUTE_TYPE
|
| 621 |
+
*/
|
| 622 |
+
CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1,
|
| 623 |
+
|
| 624 |
+
/** Pointer mode of alpha and beta, see cublasLtPointerMode_t. When CUBLASLT_POINTER_MODE_DEVICE_VECTOR is in use,
|
| 625 |
+
* alpha/beta vector lenghts must match number of output matrix rows.
|
| 626 |
+
*
|
| 627 |
+
* int32_t, default: CUBLASLT_POINTER_MODE_HOST
|
| 628 |
+
*/
|
| 629 |
+
CUBLASLT_MATMUL_DESC_POINTER_MODE = 2,
|
| 630 |
+
|
| 631 |
+
/** Transform of matrix A, see cublasOperation_t.
|
| 632 |
+
*
|
| 633 |
+
* int32_t, default: CUBLAS_OP_N
|
| 634 |
+
*/
|
| 635 |
+
CUBLASLT_MATMUL_DESC_TRANSA = 3,
|
| 636 |
+
|
| 637 |
+
/** Transform of matrix B, see cublasOperation_t.
|
| 638 |
+
*
|
| 639 |
+
* int32_t, default: CUBLAS_OP_N
|
| 640 |
+
*/
|
| 641 |
+
CUBLASLT_MATMUL_DESC_TRANSB = 4,
|
| 642 |
+
|
| 643 |
+
/** Transform of matrix C, see cublasOperation_t.
|
| 644 |
+
*
|
| 645 |
+
* Currently only CUBLAS_OP_N is supported.
|
| 646 |
+
*
|
| 647 |
+
* int32_t, default: CUBLAS_OP_N
|
| 648 |
+
*/
|
| 649 |
+
CUBLASLT_MATMUL_DESC_TRANSC = 5,
|
| 650 |
+
|
| 651 |
+
/** Matrix fill mode, see cublasFillMode_t.
|
| 652 |
+
*
|
| 653 |
+
* int32_t, default: CUBLAS_FILL_MODE_FULL
|
| 654 |
+
*/
|
| 655 |
+
CUBLASLT_MATMUL_DESC_FILL_MODE = 6,
|
| 656 |
+
|
| 657 |
+
/** Epilogue function, see cublasLtEpilogue_t.
|
| 658 |
+
*
|
| 659 |
+
* uint32_t, default: CUBLASLT_EPILOGUE_DEFAULT
|
| 660 |
+
*/
|
| 661 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE = 7,
|
| 662 |
+
|
| 663 |
+
/** Bias or bias gradient vector pointer in the device memory.
|
| 664 |
+
*
|
| 665 |
+
* Bias case. See CUBLASLT_EPILOGUE_BIAS.
|
| 666 |
+
* For bias data type see CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE.
|
| 667 |
+
*
|
| 668 |
+
* Bias vector length must match matrix D rows count.
|
| 669 |
+
*
|
| 670 |
+
* Bias gradient case. See CUBLASLT_EPILOGUE_DRELU_BGRAD and CUBLASLT_EPILOGUE_DGELU_BGRAD.
|
| 671 |
+
* Bias gradient vector elements are the same type as the output elements
|
| 672 |
+
* (Ctype) with the exception of IMMA kernels (see above).
|
| 673 |
+
*
|
| 674 |
+
* Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
|
| 675 |
+
* depend on its value to determine expected pointer alignment.
|
| 676 |
+
*
|
| 677 |
+
* Bias case: const void *, default: NULL
|
| 678 |
+
* Bias gradient case: void *, default: NULL
|
| 679 |
+
*/
|
| 680 |
+
CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8,
|
| 681 |
+
|
| 682 |
+
/** Batch stride for bias or bias gradient vector.
|
| 683 |
+
*
|
| 684 |
+
* Used together with CUBLASLT_MATMUL_DESC_BIAS_POINTER when matrix D's CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1.
|
| 685 |
+
*
|
| 686 |
+
* int64_t, default: 0
|
| 687 |
+
*/
|
| 688 |
+
CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10,
|
| 689 |
+
|
| 690 |
+
/** Pointer for epilogue auxiliary buffer.
|
| 691 |
+
*
|
| 692 |
+
* - Output vector for ReLu bit-mask in forward pass when CUBLASLT_EPILOGUE_RELU_AUX
|
| 693 |
+
* or CUBLASLT_EPILOGUE_RELU_AUX_BIAS epilogue is used.
|
| 694 |
+
* - Input vector for ReLu bit-mask in backward pass when
|
| 695 |
+
* CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is used.
|
| 696 |
+
*
|
| 697 |
+
* - Output of GELU input matrix in forward pass when
|
| 698 |
+
* CUBLASLT_EPILOGUE_GELU_AUX_BIAS epilogue is used.
|
| 699 |
+
* - Input of GELU input matrix for backward pass when
|
| 700 |
+
* CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue is used.
|
| 701 |
+
*
|
| 702 |
+
* For aux data type see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE.
|
| 703 |
+
*
|
| 704 |
+
* Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
|
| 705 |
+
* depend on its value to determine expected pointer alignment.
|
| 706 |
+
*
|
| 707 |
+
* Requires setting CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD attribute.
|
| 708 |
+
*
|
| 709 |
+
* Forward pass: void *, default: NULL
|
| 710 |
+
* Backward pass: const void *, default: NULL
|
| 711 |
+
*/
|
| 712 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11,
|
| 713 |
+
|
| 714 |
+
/** Leading dimension for epilogue auxiliary buffer.
|
| 715 |
+
*
|
| 716 |
+
* - ReLu bit-mask matrix leading dimension in elements (i.e. bits)
|
| 717 |
+
* when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
|
| 718 |
+
* used. Must be divisible by 128 and be no less than the number of rows in the output matrix.
|
| 719 |
+
*
|
| 720 |
+
* - GELU input matrix leading dimension in elements
|
| 721 |
+
* when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
|
| 722 |
+
* Must be divisible by 8 and be no less than the number of rows in the output matrix.
|
| 723 |
+
*
|
| 724 |
+
* int64_t, default: 0
|
| 725 |
+
*/
|
| 726 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12,
|
| 727 |
+
|
| 728 |
+
/** Batch stride for epilogue auxiliary buffer.
|
| 729 |
+
*
|
| 730 |
+
* - ReLu bit-mask matrix batch stride in elements (i.e. bits)
|
| 731 |
+
* when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
|
| 732 |
+
* used. Must be divisible by 128.
|
| 733 |
+
*
|
| 734 |
+
* - GELU input matrix batch stride in elements
|
| 735 |
+
* when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
|
| 736 |
+
* Must be divisible by 8.
|
| 737 |
+
*
|
| 738 |
+
* int64_t, default: 0
|
| 739 |
+
*/
|
| 740 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13,
|
| 741 |
+
|
| 742 |
+
/** Batch stride for alpha vector.
|
| 743 |
+
*
|
| 744 |
+
* Used together with CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST when matrix D's
|
| 745 |
+
* CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1. If CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO is set then
|
| 746 |
+
* CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE must be set to 0 as this mode doesnt supported batched alpha vector.
|
| 747 |
+
*
|
| 748 |
+
* int64_t, default: 0
|
| 749 |
+
*/
|
| 750 |
+
CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14,
|
| 751 |
+
|
| 752 |
+
/** Number of SMs to target for parallel execution. Optimizes heuristics for execution on a different number of SMs
|
| 753 |
+
* when user expects a concurrent stream to be using some of the device resources.
|
| 754 |
+
*
|
| 755 |
+
* int32_t, default: 0 - use the number reported by the device.
|
| 756 |
+
*/
|
| 757 |
+
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15,
|
| 758 |
+
|
| 759 |
+
/** Device pointer to the scale factor value that converts data in matrix A to the compute data type range.
|
| 760 |
+
*
|
| 761 |
+
* The scaling factor value must have the same type as the compute type.
|
| 762 |
+
*
|
| 763 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1.
|
| 764 |
+
*
|
| 765 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 766 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 767 |
+
*
|
| 768 |
+
* const void *, default: NULL
|
| 769 |
+
*/
|
| 770 |
+
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17,
|
| 771 |
+
|
| 772 |
+
/** Device pointer to the scale factor value to convert data in matrix B to compute data type range.
|
| 773 |
+
*
|
| 774 |
+
* The scaling factor value must have the same type as the compute type.
|
| 775 |
+
*
|
| 776 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1.
|
| 777 |
+
*
|
| 778 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 779 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 780 |
+
*
|
| 781 |
+
* const void *, default: NULL
|
| 782 |
+
*/
|
| 783 |
+
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18,
|
| 784 |
+
|
| 785 |
+
/** Device pointer to the scale factor value to convert data in matrix C to compute data type range.
|
| 786 |
+
*
|
| 787 |
+
* The scaling factor value must have the same type as the compute type.
|
| 788 |
+
*
|
| 789 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1.
|
| 790 |
+
*
|
| 791 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 792 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 793 |
+
*
|
| 794 |
+
* const void *, default: NULL
|
| 795 |
+
*/
|
| 796 |
+
CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19,
|
| 797 |
+
|
| 798 |
+
/** Device pointer to the scale factor value to convert data in matrix D to compute data type range.
|
| 799 |
+
*
|
| 800 |
+
* The scaling factor value must have the same type as the compute type.
|
| 801 |
+
*
|
| 802 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1.
|
| 803 |
+
*
|
| 804 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 805 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 806 |
+
*
|
| 807 |
+
* const void *, default: NULL
|
| 808 |
+
*/
|
| 809 |
+
CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20,
|
| 810 |
+
|
| 811 |
+
/** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
|
| 812 |
+
* output matrix.
|
| 813 |
+
*
|
| 814 |
+
* The computed value has the same type as the compute type.
|
| 815 |
+
*
|
| 816 |
+
* If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
|
| 817 |
+
* data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
|
| 818 |
+
*
|
| 819 |
+
* void *, default: NULL
|
| 820 |
+
*/
|
| 821 |
+
CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21,
|
| 822 |
+
|
| 823 |
+
/** Type of the data to be stored to the memory pointed to by CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 824 |
+
*
|
| 825 |
+
* If unset, the data type defaults to the type of elements of the output matrix with some exceptions, see details
|
| 826 |
+
* below.
|
| 827 |
+
*
|
| 828 |
+
* ReLu uses a bit-mask.
|
| 829 |
+
*
|
| 830 |
+
* GELU input matrix elements type is the same as the type of elements of
|
| 831 |
+
* the output matrix with some exceptions, see details below.
|
| 832 |
+
*
|
| 833 |
+
* For fp8 kernels with output type CUDA_R_8F_E4M3 the aux data type can be CUDA_R_8F_E4M3 or CUDA_R_16F with some
|
| 834 |
+
* restrictions. See https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulDescAttributes_t for more details.
|
| 835 |
+
*
|
| 836 |
+
* If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
|
| 837 |
+
* will return CUBLAS_INVALID_VALUE.
|
| 838 |
+
*
|
| 839 |
+
* int32_t based on cudaDataType, default: -1
|
| 840 |
+
*/
|
| 841 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE = 22,
|
| 842 |
+
|
| 843 |
+
/** Device pointer to the scaling factor value to convert results from compute type data range to storage
|
| 844 |
+
* data range in the auxiliary matrix that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 845 |
+
*
|
| 846 |
+
* The scaling factor value must have the same type as the compute type.
|
| 847 |
+
*
|
| 848 |
+
* If not specified, or set to NULL, the scaling factor is assumed to be 1. If set for an unsupported matrix data,
|
| 849 |
+
* scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
|
| 850 |
+
*
|
| 851 |
+
* void *, default: NULL
|
| 852 |
+
*/
|
| 853 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER = 23,
|
| 854 |
+
|
| 855 |
+
/** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
|
| 856 |
+
* buffer that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 857 |
+
*
|
| 858 |
+
* The computed value has the same type as the compute type.
|
| 859 |
+
*
|
| 860 |
+
* If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
|
| 861 |
+
* data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_INVALID_VALUE.
|
| 862 |
+
*
|
| 863 |
+
* void *, default: NULL
|
| 864 |
+
*/
|
| 865 |
+
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER = 24,
|
| 866 |
+
|
| 867 |
+
/** Flag for managing fp8 fast accumulation mode.
|
| 868 |
+
* When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results
|
| 869 |
+
* will not periodically be promoted to a higher precision.
|
| 870 |
+
*
|
| 871 |
+
* int8_t, default: 0 - fast accumulation mode is disabled.
|
| 872 |
+
*/
|
| 873 |
+
CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25,
|
| 874 |
+
|
| 875 |
+
/** Type of bias or bias gradient vector in the device memory.
|
| 876 |
+
*
|
| 877 |
+
* Bias case: see CUBLASLT_EPILOGUE_BIAS.
|
| 878 |
+
*
|
| 879 |
+
* Bias vector elements are the same type as the elements of output matrix (Dtype) with the following exceptions:
|
| 880 |
+
* - IMMA kernels with computeType=CUDA_R_32I and Ctype=CUDA_R_8I where the bias vector elements
|
| 881 |
+
* are the same type as alpha, beta (CUBLASLT_MATMUL_DESC_SCALE_TYPE=CUDA_R_32F)
|
| 882 |
+
* - fp8 kernels with an output type of CUDA_R_32F, CUDA_R_8F_E4M3 or CUDA_R_8F_E5M2, See
|
| 883 |
+
* https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul for details.
|
| 884 |
+
*
|
| 885 |
+
* int32_t based on cudaDataType, default: -1
|
| 886 |
+
*/
|
| 887 |
+
CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26,
|
| 888 |
+
} cublasLtMatmulDescAttributes_t;
|
| 889 |
+
|
| 890 |
+
/** Internal. Do not use directly.
|
| 891 |
+
*/
|
| 892 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( //
|
| 893 |
+
cublasLtMatmulDesc_t matmulDesc,
|
| 894 |
+
size_t size,
|
| 895 |
+
cublasComputeType_t computeType,
|
| 896 |
+
cudaDataType_t scaleType);
|
| 897 |
+
|
| 898 |
+
/** Initialize matmul operation descriptor in pre-allocated space.
|
| 899 |
+
*
|
| 900 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
|
| 901 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was initialized successfully
|
| 902 |
+
*/
|
| 903 |
+
static inline cublasStatus_t cublasLtMatmulDescInit( //
|
| 904 |
+
cublasLtMatmulDesc_t matmulDesc,
|
| 905 |
+
cublasComputeType_t computeType,
|
| 906 |
+
cudaDataType_t scaleType) {
|
| 907 |
+
return cublasLtMatmulDescInit_internal(matmulDesc, sizeof(*matmulDesc), computeType, scaleType);
|
| 908 |
+
}
|
| 909 |
+
|
| 910 |
+
/** Create new matmul operation descriptor.
|
| 911 |
+
*
|
| 912 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
|
| 913 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 914 |
+
*/
|
| 915 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate(cublasLtMatmulDesc_t* matmulDesc,
|
| 916 |
+
cublasComputeType_t computeType,
|
| 917 |
+
cudaDataType_t scaleType);
|
| 918 |
+
|
| 919 |
+
/** Destroy matmul operation descriptor.
|
| 920 |
+
*
|
| 921 |
+
* \retval CUBLAS_STATUS_SUCCESS if operation was successful
|
| 922 |
+
*/
|
| 923 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc);
|
| 924 |
+
|
| 925 |
+
/** Set matmul operation descriptor attribute.
|
| 926 |
+
*
|
| 927 |
+
* \param[in] matmulDesc The descriptor
|
| 928 |
+
* \param[in] attr The attribute
|
| 929 |
+
* \param[in] buf memory address containing the new value
|
| 930 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 931 |
+
*
|
| 932 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 933 |
+
* selected attribute
|
| 934 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 935 |
+
*/
|
| 936 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute( //
|
| 937 |
+
cublasLtMatmulDesc_t matmulDesc,
|
| 938 |
+
cublasLtMatmulDescAttributes_t attr,
|
| 939 |
+
const void* buf,
|
| 940 |
+
size_t sizeInBytes);
|
| 941 |
+
|
| 942 |
+
/** Get matmul operation descriptor attribute.
|
| 943 |
+
*
|
| 944 |
+
* \param[in] matmulDesc The descriptor
|
| 945 |
+
* \param[in] attr The attribute
|
| 946 |
+
* \param[out] buf memory address containing the new value
|
| 947 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 948 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 949 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 950 |
+
*
|
| 951 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 952 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 953 |
+
* selected attribute
|
| 954 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 955 |
+
*/
|
| 956 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute( //
|
| 957 |
+
cublasLtMatmulDesc_t matmulDesc,
|
| 958 |
+
cublasLtMatmulDescAttributes_t attr,
|
| 959 |
+
void* buf,
|
| 960 |
+
size_t sizeInBytes,
|
| 961 |
+
size_t* sizeWritten);
|
| 962 |
+
|
| 963 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 964 |
+
/* Helper functions for cublasLtMatrixTransformDesc_t */
|
| 965 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 966 |
+
|
| 967 |
+
/** Matrix transform descriptor attributes to define details of the operation.
|
| 968 |
+
*/
|
| 969 |
+
typedef enum {
|
| 970 |
+
/** Scale type, see cudaDataType. Inputs are converted to scale type for scaling and summation and results are then
|
| 971 |
+
* converted to output type to store in memory.
|
| 972 |
+
*
|
| 973 |
+
* int32_t
|
| 974 |
+
*/
|
| 975 |
+
CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE,
|
| 976 |
+
|
| 977 |
+
/** Pointer mode of alpha and beta, see cublasLtPointerMode_t.
|
| 978 |
+
*
|
| 979 |
+
* int32_t, default: CUBLASLT_POINTER_MODE_HOST
|
| 980 |
+
*/
|
| 981 |
+
CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE,
|
| 982 |
+
|
| 983 |
+
/** Transform of matrix A, see cublasOperation_t.
|
| 984 |
+
*
|
| 985 |
+
* int32_t, default: CUBLAS_OP_N
|
| 986 |
+
*/
|
| 987 |
+
CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA,
|
| 988 |
+
|
| 989 |
+
/** Transform of matrix B, see cublasOperation_t.
|
| 990 |
+
*
|
| 991 |
+
* int32_t, default: CUBLAS_OP_N
|
| 992 |
+
*/
|
| 993 |
+
CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSB,
|
| 994 |
+
} cublasLtMatrixTransformDescAttributes_t;
|
| 995 |
+
|
| 996 |
+
/** Internal. Do not use directly.
|
| 997 |
+
*/
|
| 998 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc,
|
| 999 |
+
size_t size,
|
| 1000 |
+
cudaDataType scaleType);
|
| 1001 |
+
|
| 1002 |
+
/** Initialize matrix transform operation descriptor in pre-allocated space.
|
| 1003 |
+
*
|
| 1004 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
|
| 1005 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 1006 |
+
*/
|
| 1007 |
+
static inline cublasStatus_t cublasLtMatrixTransformDescInit(cublasLtMatrixTransformDesc_t transformDesc,
|
| 1008 |
+
cudaDataType scaleType) {
|
| 1009 |
+
return cublasLtMatrixTransformDescInit_internal(transformDesc, sizeof(*transformDesc), scaleType);
|
| 1010 |
+
}
|
| 1011 |
+
|
| 1012 |
+
/** Create new matrix transform operation descriptor.
|
| 1013 |
+
*
|
| 1014 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
|
| 1015 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 1016 |
+
*/
|
| 1017 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t* transformDesc,
|
| 1018 |
+
cudaDataType scaleType);
|
| 1019 |
+
|
| 1020 |
+
/** Destroy matrix transform operation descriptor.
|
| 1021 |
+
*
|
| 1022 |
+
* \retval CUBLAS_STATUS_SUCCESS if operation was successful
|
| 1023 |
+
*/
|
| 1024 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc);
|
| 1025 |
+
|
| 1026 |
+
/** Set matrix transform operation descriptor attribute.
|
| 1027 |
+
*
|
| 1028 |
+
* \param[in] transformDesc The descriptor
|
| 1029 |
+
* \param[in] attr The attribute
|
| 1030 |
+
* \param[in] buf memory address containing the new value
|
| 1031 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1032 |
+
*
|
| 1033 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1034 |
+
* selected attribute
|
| 1035 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 1036 |
+
*/
|
| 1037 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute( //
|
| 1038 |
+
cublasLtMatrixTransformDesc_t transformDesc,
|
| 1039 |
+
cublasLtMatrixTransformDescAttributes_t attr,
|
| 1040 |
+
const void* buf,
|
| 1041 |
+
size_t sizeInBytes);
|
| 1042 |
+
|
| 1043 |
+
/** Get matrix transform operation descriptor attribute.
|
| 1044 |
+
*
|
| 1045 |
+
* \param[in] transformDesc The descriptor
|
| 1046 |
+
* \param[in] attr The attribute
|
| 1047 |
+
* \param[out] buf memory address containing the new value
|
| 1048 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1049 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number
|
| 1050 |
+
* of bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 1051 |
+
*
|
| 1052 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 1053 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1054 |
+
* selected attribute
|
| 1055 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 1056 |
+
*/
|
| 1057 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute( //
|
| 1058 |
+
cublasLtMatrixTransformDesc_t transformDesc,
|
| 1059 |
+
cublasLtMatrixTransformDescAttributes_t attr,
|
| 1060 |
+
void* buf,
|
| 1061 |
+
size_t sizeInBytes,
|
| 1062 |
+
size_t* sizeWritten);
|
| 1063 |
+
|
| 1064 |
+
/** For computation with complex numbers, this enum allows to apply the Gauss Complexity reduction algorithm
|
| 1065 |
+
*/
|
| 1066 |
+
typedef enum {
|
| 1067 |
+
CUBLASLT_3M_MODE_DISALLOWED = 0,
|
| 1068 |
+
CUBLASLT_3M_MODE_ALLOWED = 1,
|
| 1069 |
+
} cublasLt3mMode_t;
|
| 1070 |
+
|
| 1071 |
+
/** Reduction scheme for portions of the dot-product calculated in parallel (a. k. a. "split - K").
|
| 1072 |
+
*/
|
| 1073 |
+
typedef enum {
|
| 1074 |
+
/** No reduction scheme, dot-product shall be performed in one sequence.
|
| 1075 |
+
*/
|
| 1076 |
+
CUBLASLT_REDUCTION_SCHEME_NONE = 0,
|
| 1077 |
+
|
| 1078 |
+
/** Reduction is performed "in place" - using the output buffer (and output data type) and counters (in workspace) to
|
| 1079 |
+
* guarantee the sequentiality.
|
| 1080 |
+
*/
|
| 1081 |
+
CUBLASLT_REDUCTION_SCHEME_INPLACE = 1,
|
| 1082 |
+
|
| 1083 |
+
/** Intermediate results are stored in compute type in the workspace and reduced in a separate step.
|
| 1084 |
+
*/
|
| 1085 |
+
CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE = 2,
|
| 1086 |
+
|
| 1087 |
+
/** Intermediate results are stored in output type in the workspace and reduced in a separate step.
|
| 1088 |
+
*/
|
| 1089 |
+
CUBLASLT_REDUCTION_SCHEME_OUTPUT_TYPE = 4,
|
| 1090 |
+
|
| 1091 |
+
CUBLASLT_REDUCTION_SCHEME_MASK = 0x7,
|
| 1092 |
+
} cublasLtReductionScheme_t;
|
| 1093 |
+
|
| 1094 |
+
/** Postprocessing options for the epilogue
|
| 1095 |
+
*/
|
| 1096 |
+
typedef enum {
|
| 1097 |
+
/** No special postprocessing, just scale and quantize results if necessary.
|
| 1098 |
+
*/
|
| 1099 |
+
CUBLASLT_EPILOGUE_DEFAULT = 1,
|
| 1100 |
+
|
| 1101 |
+
/** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
|
| 1102 |
+
*/
|
| 1103 |
+
CUBLASLT_EPILOGUE_RELU = 2,
|
| 1104 |
+
|
| 1105 |
+
/** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
|
| 1106 |
+
*
|
| 1107 |
+
* This epilogue mode produces an extra output, a ReLu bit-mask matrix,
|
| 1108 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1109 |
+
*/
|
| 1110 |
+
CUBLASLT_EPILOGUE_RELU_AUX = (CUBLASLT_EPILOGUE_RELU | 128),
|
| 1111 |
+
|
| 1112 |
+
/** Bias, apply (broadcasted) Bias from bias vector. Bias vector length must match matrix D rows, it must be packed
|
| 1113 |
+
* (stride between vector elements is 1). Bias vector is broadcasted to all columns and added before applying final
|
| 1114 |
+
* postprocessing.
|
| 1115 |
+
*/
|
| 1116 |
+
CUBLASLT_EPILOGUE_BIAS = 4,
|
| 1117 |
+
|
| 1118 |
+
/** ReLu and Bias, apply Bias and then ReLu transform
|
| 1119 |
+
*/
|
| 1120 |
+
CUBLASLT_EPILOGUE_RELU_BIAS = (CUBLASLT_EPILOGUE_RELU | CUBLASLT_EPILOGUE_BIAS),
|
| 1121 |
+
|
| 1122 |
+
/** ReLu and Bias, apply Bias and then ReLu transform
|
| 1123 |
+
*
|
| 1124 |
+
* This epilogue mode produces an extra output, a ReLu bit-mask matrix,
|
| 1125 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1126 |
+
*/
|
| 1127 |
+
CUBLASLT_EPILOGUE_RELU_AUX_BIAS = (CUBLASLT_EPILOGUE_RELU_AUX | CUBLASLT_EPILOGUE_BIAS),
|
| 1128 |
+
|
| 1129 |
+
/* ReLu gradient. Apply ReLu gradient to matmul output. Store ReLu gradient in the output matrix.
|
| 1130 |
+
*
|
| 1131 |
+
* This epilogue mode requires an extra input,
|
| 1132 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1133 |
+
*/
|
| 1134 |
+
CUBLASLT_EPILOGUE_DRELU = 8 | 128,
|
| 1135 |
+
|
| 1136 |
+
/* ReLu and Bias gradients. Apply independently ReLu and Bias gradient to
|
| 1137 |
+
* matmul output. Store ReLu gradient in the output matrix, and Bias gradient
|
| 1138 |
+
* in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
|
| 1139 |
+
*
|
| 1140 |
+
* This epilogue mode requires an extra input,
|
| 1141 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1142 |
+
*/
|
| 1143 |
+
CUBLASLT_EPILOGUE_DRELU_BGRAD = CUBLASLT_EPILOGUE_DRELU | 16,
|
| 1144 |
+
|
| 1145 |
+
/** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
|
| 1146 |
+
*/
|
| 1147 |
+
CUBLASLT_EPILOGUE_GELU = 32,
|
| 1148 |
+
|
| 1149 |
+
/** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
|
| 1150 |
+
*
|
| 1151 |
+
* This epilogue mode outputs GELU input as a separate matrix (useful for training).
|
| 1152 |
+
* See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1153 |
+
*/
|
| 1154 |
+
CUBLASLT_EPILOGUE_GELU_AUX = (CUBLASLT_EPILOGUE_GELU | 128),
|
| 1155 |
+
|
| 1156 |
+
/** GELU and Bias, apply Bias and then GELU transform
|
| 1157 |
+
*/
|
| 1158 |
+
CUBLASLT_EPILOGUE_GELU_BIAS = (CUBLASLT_EPILOGUE_GELU | CUBLASLT_EPILOGUE_BIAS),
|
| 1159 |
+
|
| 1160 |
+
/** GELU and Bias, apply Bias and then GELU transform
|
| 1161 |
+
*
|
| 1162 |
+
* This epilogue mode outputs GELU input as a separate matrix (useful for training).
|
| 1163 |
+
* See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1164 |
+
*/
|
| 1165 |
+
CUBLASLT_EPILOGUE_GELU_AUX_BIAS = (CUBLASLT_EPILOGUE_GELU_AUX | CUBLASLT_EPILOGUE_BIAS),
|
| 1166 |
+
|
| 1167 |
+
/* GELU gradient. Apply GELU gradient to matmul output. Store GELU gradient in the output matrix.
|
| 1168 |
+
*
|
| 1169 |
+
* This epilogue mode requires an extra input,
|
| 1170 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1171 |
+
*/
|
| 1172 |
+
CUBLASLT_EPILOGUE_DGELU = 64 | 128,
|
| 1173 |
+
|
| 1174 |
+
/* GELU and Bias gradients. Apply independently GELU and Bias gradient to
|
| 1175 |
+
* matmul output. Store GELU gradient in the output matrix, and Bias gradient
|
| 1176 |
+
* in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
|
| 1177 |
+
*
|
| 1178 |
+
* This epilogue mode requires an extra input,
|
| 1179 |
+
* see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
|
| 1180 |
+
*/
|
| 1181 |
+
CUBLASLT_EPILOGUE_DGELU_BGRAD = CUBLASLT_EPILOGUE_DGELU | 16,
|
| 1182 |
+
|
| 1183 |
+
/** Bias gradient based on the input matrix A.
|
| 1184 |
+
*
|
| 1185 |
+
* The bias size corresponds to the number of rows of the matrix D.
|
| 1186 |
+
* The reduction happens over the GEMM's "k" dimension.
|
| 1187 |
+
*
|
| 1188 |
+
* Stores Bias gradient in the auxiliary output
|
| 1189 |
+
* (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
|
| 1190 |
+
*/
|
| 1191 |
+
CUBLASLT_EPILOGUE_BGRADA = 256,
|
| 1192 |
+
|
| 1193 |
+
/** Bias gradient based on the input matrix B.
|
| 1194 |
+
*
|
| 1195 |
+
* The bias size corresponds to the number of columns of the matrix D.
|
| 1196 |
+
* The reduction happens over the GEMM's "k" dimension.
|
| 1197 |
+
*
|
| 1198 |
+
* Stores Bias gradient in the auxiliary output
|
| 1199 |
+
* (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
|
| 1200 |
+
*/
|
| 1201 |
+
CUBLASLT_EPILOGUE_BGRADB = 512,
|
| 1202 |
+
} cublasLtEpilogue_t;
|
| 1203 |
+
|
| 1204 |
+
/** Matmul heuristic search mode
|
| 1205 |
+
*/
|
| 1206 |
+
typedef enum {
|
| 1207 |
+
/** ask heuristics for best algo for given usecase
|
| 1208 |
+
*/
|
| 1209 |
+
CUBLASLT_SEARCH_BEST_FIT = 0,
|
| 1210 |
+
/** only try to find best config for preconfigured algo id
|
| 1211 |
+
*/
|
| 1212 |
+
CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID = 1,
|
| 1213 |
+
/** reserved for future use
|
| 1214 |
+
*/
|
| 1215 |
+
CUBLASLT_SEARCH_RESERVED_02 = 2,
|
| 1216 |
+
/** reserved for future use
|
| 1217 |
+
*/
|
| 1218 |
+
CUBLASLT_SEARCH_RESERVED_03 = 3,
|
| 1219 |
+
/** reserved for future use
|
| 1220 |
+
*/
|
| 1221 |
+
CUBLASLT_SEARCH_RESERVED_04 = 4,
|
| 1222 |
+
/** reserved for future use
|
| 1223 |
+
*/
|
| 1224 |
+
CUBLASLT_SEARCH_RESERVED_05 = 5,
|
| 1225 |
+
} cublasLtMatmulSearch_t;
|
| 1226 |
+
|
| 1227 |
+
/** Algo search preference to fine tune the heuristic function. */
|
| 1228 |
+
typedef enum {
|
| 1229 |
+
/** Search mode, see cublasLtMatmulSearch_t.
|
| 1230 |
+
*
|
| 1231 |
+
* uint32_t, default: CUBLASLT_SEARCH_BEST_FIT
|
| 1232 |
+
*/
|
| 1233 |
+
CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0,
|
| 1234 |
+
|
| 1235 |
+
/** Maximum allowed workspace size in bytes.
|
| 1236 |
+
*
|
| 1237 |
+
* uint64_t, default: 0 - no workspace allowed
|
| 1238 |
+
*/
|
| 1239 |
+
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1,
|
| 1240 |
+
|
| 1241 |
+
/** Math mode mask, see cublasMath_t.
|
| 1242 |
+
*
|
| 1243 |
+
* Only algorithms with CUBLASLT_ALGO_CAP_MATHMODE_IMPL that is not masked out by this attribute are allowed.
|
| 1244 |
+
*
|
| 1245 |
+
* uint32_t, default: 1 (allows both default and tensor op math)
|
| 1246 |
+
* DEPRECATED, will be removed in a future release, see cublasLtNumericalImplFlags_t for replacement
|
| 1247 |
+
*/
|
| 1248 |
+
CUBLASLT_MATMUL_PREF_MATH_MODE_MASK = 2,
|
| 1249 |
+
|
| 1250 |
+
/** Reduction scheme mask, see cublasLtReductionScheme_t. Filters heuristic result to only include algo configs that
|
| 1251 |
+
* use one of the required modes.
|
| 1252 |
+
*
|
| 1253 |
+
* E.g. mask value of 0x03 will allow only INPLACE and COMPUTE_TYPE reduction schemes.
|
| 1254 |
+
*
|
| 1255 |
+
* uint32_t, default: CUBLASLT_REDUCTION_SCHEME_MASK (allows all reduction schemes)
|
| 1256 |
+
*/
|
| 1257 |
+
CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3,
|
| 1258 |
+
|
| 1259 |
+
/** Gaussian mode mask, see cublasLt3mMode_t.
|
| 1260 |
+
*
|
| 1261 |
+
* Only algorithms with CUBLASLT_ALGO_CAP_GAUSSIAN_IMPL that is not masked out by this attribute are allowed.
|
| 1262 |
+
*
|
| 1263 |
+
* uint32_t, default: CUBLASLT_3M_MODE_ALLOWED (allows both gaussian and non-gaussian algorithms)
|
| 1264 |
+
* DEPRECATED, will be removed in a future release, see cublasLtNumericalImplFlags_t for replacement
|
| 1265 |
+
*/
|
| 1266 |
+
CUBLASLT_MATMUL_PREF_GAUSSIAN_MODE_MASK = 4,
|
| 1267 |
+
|
| 1268 |
+
/** Minimum buffer alignment for matrix A (in bytes).
|
| 1269 |
+
*
|
| 1270 |
+
* Selecting a smaller value will exclude algorithms that can not work with matrix A that is not as strictly aligned
|
| 1271 |
+
* as they need.
|
| 1272 |
+
*
|
| 1273 |
+
* uint32_t, default: 256
|
| 1274 |
+
*/
|
| 1275 |
+
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5,
|
| 1276 |
+
|
| 1277 |
+
/** Minimum buffer alignment for matrix B (in bytes).
|
| 1278 |
+
*
|
| 1279 |
+
* Selecting a smaller value will exclude algorithms that can not work with matrix B that is not as strictly aligned
|
| 1280 |
+
* as they need.
|
| 1281 |
+
*
|
| 1282 |
+
* uint32_t, default: 256
|
| 1283 |
+
*/
|
| 1284 |
+
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6,
|
| 1285 |
+
|
| 1286 |
+
/** Minimum buffer alignment for matrix C (in bytes).
|
| 1287 |
+
*
|
| 1288 |
+
* Selecting a smaller value will exclude algorithms that can not work with matrix C that is not as strictly aligned
|
| 1289 |
+
* as they need.
|
| 1290 |
+
*
|
| 1291 |
+
* uint32_t, default: 256
|
| 1292 |
+
*/
|
| 1293 |
+
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7,
|
| 1294 |
+
|
| 1295 |
+
/** Minimum buffer alignment for matrix D (in bytes).
|
| 1296 |
+
*
|
| 1297 |
+
* Selecting a smaller value will exclude algorithms that can not work with matrix D that is not as strictly aligned
|
| 1298 |
+
* as they need.
|
| 1299 |
+
*
|
| 1300 |
+
* uint32_t, default: 256
|
| 1301 |
+
*/
|
| 1302 |
+
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8,
|
| 1303 |
+
|
| 1304 |
+
/** Maximum wave count.
|
| 1305 |
+
*
|
| 1306 |
+
* See cublasLtMatmulHeuristicResult_t::wavesCount.
|
| 1307 |
+
*
|
| 1308 |
+
* Selecting a non-zero value will exclude algorithms that report device utilization higher than specified.
|
| 1309 |
+
*
|
| 1310 |
+
* float, default: 0.0f
|
| 1311 |
+
*/
|
| 1312 |
+
CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9,
|
| 1313 |
+
|
| 1314 |
+
/** Pointer mode mask, see cublasLtPointerModeMask_t. Filters heuristic result to only include algorithms that support
|
| 1315 |
+
* all required modes.
|
| 1316 |
+
*
|
| 1317 |
+
* uint32_t, default: (CUBLASLT_POINTER_MODE_MASK_HOST | CUBLASLT_POINTER_MODE_MASK_DEVICE) (only allows algorithms
|
| 1318 |
+
* that support both regular host and device pointers)
|
| 1319 |
+
*/
|
| 1320 |
+
CUBLASLT_MATMUL_PREF_POINTER_MODE_MASK = 10,
|
| 1321 |
+
|
| 1322 |
+
/** Epilogue selector mask, see cublasLtEpilogue_t. Filters heuristic result to only include algorithms that support
|
| 1323 |
+
* all required operations.
|
| 1324 |
+
*
|
| 1325 |
+
* uint32_t, default: CUBLASLT_EPILOGUE_DEFAULT (only allows algorithms that support default epilogue)
|
| 1326 |
+
*/
|
| 1327 |
+
CUBLASLT_MATMUL_PREF_EPILOGUE_MASK = 11,
|
| 1328 |
+
|
| 1329 |
+
/** Numerical implementation details mask, see cublasLtNumericalImplFlags_t. Filters heuristic result to only include
|
| 1330 |
+
* algorithms that use the allowed implementations.
|
| 1331 |
+
*
|
| 1332 |
+
* uint64_t, default: uint64_t(-1) (allow everything)
|
| 1333 |
+
*/
|
| 1334 |
+
CUBLASLT_MATMUL_PREF_IMPL_MASK = 12,
|
| 1335 |
+
|
| 1336 |
+
/** Number of SMs to target for parallel execution. Optimizes heuristics for execution on a different number of SMs
|
| 1337 |
+
* when user expects a concurrent stream to be using some of the device resources.
|
| 1338 |
+
*
|
| 1339 |
+
* Overrides the SM count target set in the matrix multiplication descriptor (see cublasLtMatmulDescAttributes_t).
|
| 1340 |
+
*
|
| 1341 |
+
* int32_t, default: 0 - use the number reported by the device.
|
| 1342 |
+
* DEPRECATED, will be removed in a future release, see cublasLtMatmulDescAttributes_t for replacement
|
| 1343 |
+
*/
|
| 1344 |
+
CUBLASLT_MATMUL_PREF_SM_COUNT_TARGET = 13,
|
| 1345 |
+
} cublasLtMatmulPreferenceAttributes_t;
|
| 1346 |
+
|
| 1347 |
+
/** Internal. Do not use directly.
|
| 1348 |
+
*/
|
| 1349 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size);
|
| 1350 |
+
|
| 1351 |
+
/** Initialize matmul heuristic search preference descriptor in pre-allocated space.
|
| 1352 |
+
*
|
| 1353 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient
|
| 1354 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 1355 |
+
*/
|
| 1356 |
+
static inline cublasStatus_t cublasLtMatmulPreferenceInit(cublasLtMatmulPreference_t pref) {
|
| 1357 |
+
return cublasLtMatmulPreferenceInit_internal(pref, sizeof(*pref));
|
| 1358 |
+
}
|
| 1359 |
+
|
| 1360 |
+
/** Create new matmul heuristic search preference descriptor.
|
| 1361 |
+
*
|
| 1362 |
+
* \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated
|
| 1363 |
+
* \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully
|
| 1364 |
+
*/
|
| 1365 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t* pref);
|
| 1366 |
+
|
| 1367 |
+
/** Destroy matmul heuristic search preference descriptor.
|
| 1368 |
+
*
|
| 1369 |
+
* \retval CUBLAS_STATUS_SUCCESS if operation was successful
|
| 1370 |
+
*/
|
| 1371 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref);
|
| 1372 |
+
|
| 1373 |
+
/** Set matmul heuristic search preference descriptor attribute.
|
| 1374 |
+
*
|
| 1375 |
+
* \param[in] pref The descriptor
|
| 1376 |
+
* \param[in] attr The attribute
|
| 1377 |
+
* \param[in] buf memory address containing the new value
|
| 1378 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1379 |
+
*
|
| 1380 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1381 |
+
* selected attribute
|
| 1382 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 1383 |
+
*/
|
| 1384 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( //
|
| 1385 |
+
cublasLtMatmulPreference_t pref,
|
| 1386 |
+
cublasLtMatmulPreferenceAttributes_t attr,
|
| 1387 |
+
const void* buf,
|
| 1388 |
+
size_t sizeInBytes);
|
| 1389 |
+
|
| 1390 |
+
/** Get matmul heuristic search preference descriptor attribute.
|
| 1391 |
+
*
|
| 1392 |
+
* \param[in] pref The descriptor
|
| 1393 |
+
* \param[in] attr The attribute
|
| 1394 |
+
* \param[out] buf memory address containing the new value
|
| 1395 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1396 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 1397 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 1398 |
+
*
|
| 1399 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 1400 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1401 |
+
* selected attribute
|
| 1402 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 1403 |
+
*/
|
| 1404 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( //
|
| 1405 |
+
cublasLtMatmulPreference_t pref,
|
| 1406 |
+
cublasLtMatmulPreferenceAttributes_t attr,
|
| 1407 |
+
void* buf,
|
| 1408 |
+
size_t sizeInBytes,
|
| 1409 |
+
size_t* sizeWritten);
|
| 1410 |
+
|
| 1411 |
+
/** Results structure used by cublasLtMatmulGetAlgo.
|
| 1412 |
+
*
|
| 1413 |
+
* Holds returned configured algo descriptor and its runtime properties.
|
| 1414 |
+
*/
|
| 1415 |
+
typedef struct {
|
| 1416 |
+
/** Matmul algorithm descriptor.
|
| 1417 |
+
*
|
| 1418 |
+
* Must be initialized with cublasLtMatmulAlgoInit() if preferences' CUBLASLT_MATMUL_PERF_SEARCH_MODE is set to
|
| 1419 |
+
* CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID
|
| 1420 |
+
*/
|
| 1421 |
+
cublasLtMatmulAlgo_t algo;
|
| 1422 |
+
|
| 1423 |
+
/** Actual size of workspace memory required.
|
| 1424 |
+
*/
|
| 1425 |
+
size_t workspaceSize;
|
| 1426 |
+
|
| 1427 |
+
/** Result status, other fields are only valid if after call to cublasLtMatmulAlgoGetHeuristic() this member is set to
|
| 1428 |
+
* CUBLAS_STATUS_SUCCESS.
|
| 1429 |
+
*/
|
| 1430 |
+
cublasStatus_t state;
|
| 1431 |
+
|
| 1432 |
+
/** Waves count - a device utilization metric.
|
| 1433 |
+
*
|
| 1434 |
+
* wavesCount value of 1.0f suggests that when kernel is launched it will fully occupy the GPU.
|
| 1435 |
+
*/
|
| 1436 |
+
float wavesCount;
|
| 1437 |
+
|
| 1438 |
+
int reserved[4];
|
| 1439 |
+
} cublasLtMatmulHeuristicResult_t;
|
| 1440 |
+
|
| 1441 |
+
/** Query cublasLt heuristic for algorithm appropriate for given use case.
|
| 1442 |
+
*
|
| 1443 |
+
* \param[in] lightHandle Pointer to the allocated cuBLASLt handle for the cuBLASLt
|
| 1444 |
+
* context. See cublasLtHandle_t.
|
| 1445 |
+
* \param[in] operationDesc Handle to the matrix multiplication descriptor.
|
| 1446 |
+
* \param[in] Adesc Handle to the layout descriptors for matrix A.
|
| 1447 |
+
* \param[in] Bdesc Handle to the layout descriptors for matrix B.
|
| 1448 |
+
* \param[in] Cdesc Handle to the layout descriptors for matrix C.
|
| 1449 |
+
* \param[in] Ddesc Handle to the layout descriptors for matrix D.
|
| 1450 |
+
* \param[in] preference Pointer to the structure holding the heuristic search
|
| 1451 |
+
* preferences descriptor. See cublasLtMatrixLayout_t.
|
| 1452 |
+
* \param[in] requestedAlgoCount Size of heuristicResultsArray (in elements) and requested
|
| 1453 |
+
* maximum number of algorithms to return.
|
| 1454 |
+
* \param[in, out] heuristicResultsArray Output algorithms and associated runtime characteristics,
|
| 1455 |
+
* ordered in increasing estimated compute time.
|
| 1456 |
+
* \param[out] returnAlgoCount The number of heuristicResultsArray elements written.
|
| 1457 |
+
*
|
| 1458 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero
|
| 1459 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if no heuristic function available for current configuration
|
| 1460 |
+
* \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect
|
| 1461 |
+
* heuristicResultsArray[0 to (returnAlgoCount - 1)].state
|
| 1462 |
+
* for detail status of results
|
| 1463 |
+
*/
|
| 1464 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic(cublasLtHandle_t lightHandle,
|
| 1465 |
+
cublasLtMatmulDesc_t operationDesc,
|
| 1466 |
+
cublasLtMatrixLayout_t Adesc,
|
| 1467 |
+
cublasLtMatrixLayout_t Bdesc,
|
| 1468 |
+
cublasLtMatrixLayout_t Cdesc,
|
| 1469 |
+
cublasLtMatrixLayout_t Ddesc,
|
| 1470 |
+
cublasLtMatmulPreference_t preference,
|
| 1471 |
+
int requestedAlgoCount,
|
| 1472 |
+
cublasLtMatmulHeuristicResult_t heuristicResultsArray[],
|
| 1473 |
+
int* returnAlgoCount);
|
| 1474 |
+
|
| 1475 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 1476 |
+
/* Lower level API to be able to implement own Heuristic and Find routines */
|
| 1477 |
+
/* ---------------------------------------------------------------------------------------*/
|
| 1478 |
+
|
| 1479 |
+
/** Routine to get all algo IDs that can potentially run
|
| 1480 |
+
*
|
| 1481 |
+
* \param[in] int requestedAlgoCount requested number of algos (must be less or equal to size of algoIdsA
|
| 1482 |
+
* (in elements)) \param[out] algoIdsA array to write algoIds to \param[out] returnAlgoCount number of algoIds
|
| 1483 |
+
* actually written
|
| 1484 |
+
*
|
| 1485 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero
|
| 1486 |
+
* \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect returnAlgoCount to get actual number of IDs
|
| 1487 |
+
* available
|
| 1488 |
+
*/
|
| 1489 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds(cublasLtHandle_t lightHandle,
|
| 1490 |
+
cublasComputeType_t computeType,
|
| 1491 |
+
cudaDataType_t scaleType,
|
| 1492 |
+
cudaDataType_t Atype,
|
| 1493 |
+
cudaDataType_t Btype,
|
| 1494 |
+
cudaDataType_t Ctype,
|
| 1495 |
+
cudaDataType_t Dtype,
|
| 1496 |
+
int requestedAlgoCount,
|
| 1497 |
+
int algoIdsArray[],
|
| 1498 |
+
int* returnAlgoCount);
|
| 1499 |
+
|
| 1500 |
+
/** Initialize algo structure
|
| 1501 |
+
*
|
| 1502 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if algo is NULL or algoId is outside of recognized range
|
| 1503 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if algoId is not supported for given combination of data types
|
| 1504 |
+
* \retval CUBLAS_STATUS_SUCCESS if the structure was successfully initialized
|
| 1505 |
+
*/
|
| 1506 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit(cublasLtHandle_t lightHandle,
|
| 1507 |
+
cublasComputeType_t computeType,
|
| 1508 |
+
cudaDataType_t scaleType,
|
| 1509 |
+
cudaDataType_t Atype,
|
| 1510 |
+
cudaDataType_t Btype,
|
| 1511 |
+
cudaDataType_t Ctype,
|
| 1512 |
+
cudaDataType_t Dtype,
|
| 1513 |
+
int algoId,
|
| 1514 |
+
cublasLtMatmulAlgo_t* algo);
|
| 1515 |
+
|
| 1516 |
+
/** Check configured algo descriptor for correctness and support on current device.
|
| 1517 |
+
*
|
| 1518 |
+
* Result includes required workspace size and calculated wave count.
|
| 1519 |
+
*
|
| 1520 |
+
* CUBLAS_STATUS_SUCCESS doesn't fully guarantee algo will run (will fail if e.g. buffers are not correctly aligned);
|
| 1521 |
+
* but if cublasLtMatmulAlgoCheck fails, the algo will not run.
|
| 1522 |
+
*
|
| 1523 |
+
* \param[in] algo algo configuration to check
|
| 1524 |
+
* \param[out] result result structure to report algo runtime characteristics; algo field is never updated
|
| 1525 |
+
*
|
| 1526 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if matrix layout descriptors or operation descriptor don't match algo
|
| 1527 |
+
* descriptor
|
| 1528 |
+
* \retval CUBLAS_STATUS_NOT_SUPPORTED if algo configuration or data type combination is not currently supported on
|
| 1529 |
+
* given device
|
| 1530 |
+
* \retval CUBLAS_STATUS_ARCH_MISMATCH if algo configuration cannot be run using the selected device
|
| 1531 |
+
* \retval CUBLAS_STATUS_SUCCESS if check was successful
|
| 1532 |
+
*/
|
| 1533 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( //
|
| 1534 |
+
cublasLtHandle_t lightHandle,
|
| 1535 |
+
cublasLtMatmulDesc_t operationDesc,
|
| 1536 |
+
cublasLtMatrixLayout_t Adesc,
|
| 1537 |
+
cublasLtMatrixLayout_t Bdesc,
|
| 1538 |
+
cublasLtMatrixLayout_t Cdesc,
|
| 1539 |
+
cublasLtMatrixLayout_t Ddesc,
|
| 1540 |
+
const cublasLtMatmulAlgo_t* algo, ///< may point to result->algo
|
| 1541 |
+
cublasLtMatmulHeuristicResult_t* result);
|
| 1542 |
+
|
| 1543 |
+
/** Capabilities Attributes that can be retrieved from an initialized Algo structure
|
| 1544 |
+
*/
|
| 1545 |
+
typedef enum {
|
| 1546 |
+
/** support for split K, see CUBLASLT_ALGO_CONFIG_SPLITK_NUM
|
| 1547 |
+
*
|
| 1548 |
+
* int32_t, 0 means no support, supported otherwise
|
| 1549 |
+
*/
|
| 1550 |
+
CUBLASLT_ALGO_CAP_SPLITK_SUPPORT = 0,
|
| 1551 |
+
/** reduction scheme mask, see cublasLtReductionScheme_t; shows supported reduction schemes, if reduction scheme is
|
| 1552 |
+
* not masked out it is supported.
|
| 1553 |
+
*
|
| 1554 |
+
* e.g. int isReductionSchemeComputeTypeSupported ? (reductionSchemeMask & CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE) ==
|
| 1555 |
+
* CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE ? 1 : 0;
|
| 1556 |
+
*
|
| 1557 |
+
* uint32_t
|
| 1558 |
+
*/
|
| 1559 |
+
CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK = 1,
|
| 1560 |
+
/** support for cta swizzling, see CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING
|
| 1561 |
+
*
|
| 1562 |
+
* uint32_t, 0 means no support, 1 means supported value of 1, other values are reserved
|
| 1563 |
+
*/
|
| 1564 |
+
CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT = 2,
|
| 1565 |
+
/** support strided batch
|
| 1566 |
+
*
|
| 1567 |
+
* int32_t, 0 means no support, supported otherwise
|
| 1568 |
+
*/
|
| 1569 |
+
CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT = 3,
|
| 1570 |
+
/** support results out of place (D != C in D = alpha.A.B + beta.C)
|
| 1571 |
+
*
|
| 1572 |
+
* int32_t, 0 means no support, supported otherwise
|
| 1573 |
+
*/
|
| 1574 |
+
CUBLASLT_ALGO_CAP_OUT_OF_PLACE_RESULT_SUPPORT = 4,
|
| 1575 |
+
/** syrk/herk support (on top of regular gemm)
|
| 1576 |
+
*
|
| 1577 |
+
* int32_t, 0 means no support, supported otherwise
|
| 1578 |
+
*/
|
| 1579 |
+
CUBLASLT_ALGO_CAP_UPLO_SUPPORT = 5,
|
| 1580 |
+
/** tile ids possible to use, see cublasLtMatmulTile_t; if no tile ids are supported use
|
| 1581 |
+
* CUBLASLT_MATMUL_TILE_UNDEFINED
|
| 1582 |
+
*
|
| 1583 |
+
* use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count
|
| 1584 |
+
*
|
| 1585 |
+
* array of uint32_t
|
| 1586 |
+
*/
|
| 1587 |
+
CUBLASLT_ALGO_CAP_TILE_IDS = 6,
|
| 1588 |
+
/** custom option range is from 0 to CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX (inclusive), see
|
| 1589 |
+
* CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION
|
| 1590 |
+
*
|
| 1591 |
+
* int32_t
|
| 1592 |
+
*/
|
| 1593 |
+
CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX = 7,
|
| 1594 |
+
/** whether algorithm is using regular compute or tensor operations
|
| 1595 |
+
*
|
| 1596 |
+
* int32_t 0 means regular compute, 1 means tensor operations;
|
| 1597 |
+
* DEPRECATED
|
| 1598 |
+
*/
|
| 1599 |
+
CUBLASLT_ALGO_CAP_MATHMODE_IMPL = 8,
|
| 1600 |
+
/** whether algorithm implements gaussian optimization of complex matrix multiplication, see cublasMath_t
|
| 1601 |
+
*
|
| 1602 |
+
* int32_t 0 means regular compute, 1 means gaussian;
|
| 1603 |
+
* DEPRECATED
|
| 1604 |
+
*/
|
| 1605 |
+
CUBLASLT_ALGO_CAP_GAUSSIAN_IMPL = 9,
|
| 1606 |
+
/** whether algorithm supports custom (not COL or ROW memory order), see cublasLtOrder_t
|
| 1607 |
+
*
|
| 1608 |
+
* int32_t 0 means only COL and ROW memory order is allowed, non-zero means that algo might have different
|
| 1609 |
+
* requirements;
|
| 1610 |
+
*/
|
| 1611 |
+
CUBLASLT_ALGO_CAP_CUSTOM_MEMORY_ORDER = 10,
|
| 1612 |
+
|
| 1613 |
+
/** bitmask enumerating pointer modes algorithm supports
|
| 1614 |
+
*
|
| 1615 |
+
* uint32_t, see cublasLtPointerModeMask_t
|
| 1616 |
+
*/
|
| 1617 |
+
CUBLASLT_ALGO_CAP_POINTER_MODE_MASK = 11,
|
| 1618 |
+
|
| 1619 |
+
/** bitmask enumerating kinds of postprocessing algorithm supports in the epilogue
|
| 1620 |
+
*
|
| 1621 |
+
* uint32_t, see cublasLtEpilogue_t
|
| 1622 |
+
*/
|
| 1623 |
+
CUBLASLT_ALGO_CAP_EPILOGUE_MASK = 12,
|
| 1624 |
+
/** stages ids possible to use, see cublasLtMatmulStages_t; if no stages ids are supported use
|
| 1625 |
+
* CUBLASLT_MATMUL_STAGES_UNDEFINED
|
| 1626 |
+
*
|
| 1627 |
+
* use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count
|
| 1628 |
+
*
|
| 1629 |
+
* array of uint32_t
|
| 1630 |
+
*/
|
| 1631 |
+
CUBLASLT_ALGO_CAP_STAGES_IDS = 13,
|
| 1632 |
+
/** support for nagative ld for all of the matrices
|
| 1633 |
+
*
|
| 1634 |
+
* int32_t 0 means no support, supported otherwise
|
| 1635 |
+
*/
|
| 1636 |
+
CUBLASLT_ALGO_CAP_LD_NEGATIVE = 14,
|
| 1637 |
+
/** details about algorithm's implementation that affect it's numerical behavior
|
| 1638 |
+
*
|
| 1639 |
+
* uint64_t, see cublasLtNumericalImplFlags_t
|
| 1640 |
+
*/
|
| 1641 |
+
CUBLASLT_ALGO_CAP_NUMERICAL_IMPL_FLAGS = 15,
|
| 1642 |
+
/** minimum alignment required for A matrix in bytes
|
| 1643 |
+
* (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
|
| 1644 |
+
*
|
| 1645 |
+
* uint32_t
|
| 1646 |
+
*/
|
| 1647 |
+
CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_A_BYTES = 16,
|
| 1648 |
+
/** minimum alignment required for B matrix in bytes
|
| 1649 |
+
* (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
|
| 1650 |
+
*
|
| 1651 |
+
* uint32_t
|
| 1652 |
+
*/
|
| 1653 |
+
CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_B_BYTES = 17,
|
| 1654 |
+
/** minimum alignment required for C matrix in bytes
|
| 1655 |
+
* (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
|
| 1656 |
+
*
|
| 1657 |
+
* uint32_t
|
| 1658 |
+
*/
|
| 1659 |
+
CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_C_BYTES = 18,
|
| 1660 |
+
/** minimum alignment required for D matrix in bytes
|
| 1661 |
+
* (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order)
|
| 1662 |
+
*
|
| 1663 |
+
* uint32_t
|
| 1664 |
+
*/
|
| 1665 |
+
CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_D_BYTES = 19,
|
| 1666 |
+
} cublasLtMatmulAlgoCapAttributes_t;
|
| 1667 |
+
|
| 1668 |
+
/** Get algo capability attribute.
|
| 1669 |
+
*
|
| 1670 |
+
* E.g. to get list of supported Tile IDs:
|
| 1671 |
+
* cublasLtMatmulTile_t tiles[CUBLASLT_MATMUL_TILE_END];
|
| 1672 |
+
* size_t num_tiles, size_written;
|
| 1673 |
+
* if (cublasLtMatmulAlgoCapGetAttribute(algo, CUBLASLT_ALGO_CAP_TILE_IDS, tiles, sizeof(tiles), size_written) ==
|
| 1674 |
+
* CUBLAS_STATUS_SUCCESS) { num_tiles = size_written / sizeof(tiles[0]);
|
| 1675 |
+
* }
|
| 1676 |
+
*
|
| 1677 |
+
* \param[in] algo The algo descriptor
|
| 1678 |
+
* \param[in] attr The attribute
|
| 1679 |
+
* \param[out] buf memory address containing the new value
|
| 1680 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1681 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 1682 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 1683 |
+
*
|
| 1684 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 1685 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1686 |
+
* selected attribute
|
| 1687 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 1688 |
+
*/
|
| 1689 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute(const cublasLtMatmulAlgo_t* algo,
|
| 1690 |
+
cublasLtMatmulAlgoCapAttributes_t attr,
|
| 1691 |
+
void* buf,
|
| 1692 |
+
size_t sizeInBytes,
|
| 1693 |
+
size_t* sizeWritten);
|
| 1694 |
+
|
| 1695 |
+
/** Algo Configuration Attributes that can be set according to the Algo capabilities
|
| 1696 |
+
*/
|
| 1697 |
+
typedef enum {
|
| 1698 |
+
/** algorithm index, see cublasLtMatmulAlgoGetIds()
|
| 1699 |
+
*
|
| 1700 |
+
* readonly, set by cublasLtMatmulAlgoInit()
|
| 1701 |
+
* int32_t
|
| 1702 |
+
*/
|
| 1703 |
+
CUBLASLT_ALGO_CONFIG_ID = 0,
|
| 1704 |
+
/** tile id, see cublasLtMatmulTile_t
|
| 1705 |
+
*
|
| 1706 |
+
* uint32_t, default: CUBLASLT_MATMUL_TILE_UNDEFINED
|
| 1707 |
+
*/
|
| 1708 |
+
CUBLASLT_ALGO_CONFIG_TILE_ID = 1,
|
| 1709 |
+
/** Number of K splits. If the number of K splits is greater than one, SPLITK_NUM parts
|
| 1710 |
+
* of matrix multiplication will be computed in parallel. The results will be accumulated
|
| 1711 |
+
* according to CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME
|
| 1712 |
+
*
|
| 1713 |
+
* int32_t, default: 1
|
| 1714 |
+
*/
|
| 1715 |
+
CUBLASLT_ALGO_CONFIG_SPLITK_NUM = 2,
|
| 1716 |
+
/** reduction scheme, see cublasLtReductionScheme_t
|
| 1717 |
+
*
|
| 1718 |
+
* uint32_t, default: CUBLASLT_REDUCTION_SCHEME_NONE
|
| 1719 |
+
*/
|
| 1720 |
+
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME = 3,
|
| 1721 |
+
/** cta swizzling, change mapping from CUDA grid coordinates to parts of the matrices
|
| 1722 |
+
*
|
| 1723 |
+
* possible values: 0, 1, other values reserved
|
| 1724 |
+
*
|
| 1725 |
+
* uint32_t, default: 0
|
| 1726 |
+
*/
|
| 1727 |
+
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING = 4,
|
| 1728 |
+
/** custom option, each algorithm can support some custom options that don't fit description of the other config
|
| 1729 |
+
* attributes, see CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX to get accepted range for any specific case
|
| 1730 |
+
*
|
| 1731 |
+
* uint32_t, default: 0
|
| 1732 |
+
*/
|
| 1733 |
+
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION = 5,
|
| 1734 |
+
/** stages id, see cublasLtMatmulStages_t
|
| 1735 |
+
*
|
| 1736 |
+
* uint32_t, default: CUBLASLT_MATMUL_STAGES_UNDEFINED
|
| 1737 |
+
*/
|
| 1738 |
+
CUBLASLT_ALGO_CONFIG_STAGES_ID = 6,
|
| 1739 |
+
/** inner shape id, see cublasLtMatmulInnerShape_t
|
| 1740 |
+
*
|
| 1741 |
+
* uint16_t, default: 0 (CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED)
|
| 1742 |
+
*/
|
| 1743 |
+
CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID = 7,
|
| 1744 |
+
/** Thread Block Cluster shape id, see cublasLtClusterShape_t. Defines cluster size to use.
|
| 1745 |
+
*
|
| 1746 |
+
* uint16_t, default: 0 (CUBLASLT_CLUSTER_SHAPE_AUTO)
|
| 1747 |
+
*/
|
| 1748 |
+
CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID = 8,
|
| 1749 |
+
} cublasLtMatmulAlgoConfigAttributes_t;
|
| 1750 |
+
|
| 1751 |
+
/** Set algo configuration attribute.
|
| 1752 |
+
*
|
| 1753 |
+
* \param[in] algo The algo descriptor
|
| 1754 |
+
* \param[in] attr The attribute
|
| 1755 |
+
* \param[in] buf memory address containing the new value
|
| 1756 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1757 |
+
*
|
| 1758 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1759 |
+
* selected attribute
|
| 1760 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully
|
| 1761 |
+
*/
|
| 1762 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute(cublasLtMatmulAlgo_t* algo,
|
| 1763 |
+
cublasLtMatmulAlgoConfigAttributes_t attr,
|
| 1764 |
+
const void* buf,
|
| 1765 |
+
size_t sizeInBytes);
|
| 1766 |
+
|
| 1767 |
+
/** Get algo configuration attribute.
|
| 1768 |
+
*
|
| 1769 |
+
* \param[in] algo The algo descriptor
|
| 1770 |
+
* \param[in] attr The attribute
|
| 1771 |
+
* \param[out] buf memory address containing the new value
|
| 1772 |
+
* \param[in] sizeInBytes size of buf buffer for verification (in bytes)
|
| 1773 |
+
* \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of
|
| 1774 |
+
* bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
|
| 1775 |
+
*
|
| 1776 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
|
| 1777 |
+
* and buf is NULL or sizeInBytes doesn't match size of internal storage for
|
| 1778 |
+
* selected attribute
|
| 1779 |
+
* \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory
|
| 1780 |
+
*/
|
| 1781 |
+
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute(const cublasLtMatmulAlgo_t* algo,
|
| 1782 |
+
cublasLtMatmulAlgoConfigAttributes_t attr,
|
| 1783 |
+
void* buf,
|
| 1784 |
+
size_t sizeInBytes,
|
| 1785 |
+
size_t* sizeWritten);
|
| 1786 |
+
|
| 1787 |
+
/** Experimental: Logger callback type.
|
| 1788 |
+
*/
|
| 1789 |
+
typedef void (*cublasLtLoggerCallback_t)(int logLevel, const char* functionName, const char* message);
|
| 1790 |
+
|
| 1791 |
+
/** Experimental: Logger callback setter.
|
| 1792 |
+
*
|
| 1793 |
+
* \param[in] callback a user defined callback function to be called by the logger
|
| 1794 |
+
*
|
| 1795 |
+
* \retval CUBLAS_STATUS_SUCCESS if callback was set successfully
|
| 1796 |
+
*/
|
| 1797 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerSetCallback(cublasLtLoggerCallback_t callback);
|
| 1798 |
+
|
| 1799 |
+
/** Experimental: Log file setter.
|
| 1800 |
+
*
|
| 1801 |
+
* \param[in] file an open file with write permissions
|
| 1802 |
+
*
|
| 1803 |
+
* \retval CUBLAS_STATUS_SUCCESS if log file was set successfully
|
| 1804 |
+
*/
|
| 1805 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerSetFile(FILE* file);
|
| 1806 |
+
|
| 1807 |
+
/** Experimental: Open log file.
|
| 1808 |
+
*
|
| 1809 |
+
* \param[in] logFile log file path. if the log file does not exist, it will be created
|
| 1810 |
+
*
|
| 1811 |
+
* \retval CUBLAS_STATUS_SUCCESS if log file was created successfully
|
| 1812 |
+
*/
|
| 1813 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerOpenFile(const char* logFile);
|
| 1814 |
+
|
| 1815 |
+
/** Experimental: Log level setter.
|
| 1816 |
+
*
|
| 1817 |
+
* \param[in] level log level, should be one of the following:
|
| 1818 |
+
* 0. Off
|
| 1819 |
+
* 1. Errors
|
| 1820 |
+
* 2. Performance Trace
|
| 1821 |
+
* 3. Performance Hints
|
| 1822 |
+
* 4. Heuristics Trace
|
| 1823 |
+
* 5. API Trace
|
| 1824 |
+
*
|
| 1825 |
+
* \retval CUBLAS_STATUS_INVALID_VALUE if log level is not one of the above levels
|
| 1826 |
+
*
|
| 1827 |
+
* \retval CUBLAS_STATUS_SUCCESS if log level was set successfully
|
| 1828 |
+
*/
|
| 1829 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerSetLevel(int level);
|
| 1830 |
+
|
| 1831 |
+
/** Experimental: Log mask setter.
|
| 1832 |
+
*
|
| 1833 |
+
* \param[in] mask log mask, should be a combination of the following masks:
|
| 1834 |
+
* 0. Off
|
| 1835 |
+
* 1. Errors
|
| 1836 |
+
* 2. Performance Trace
|
| 1837 |
+
* 4. Performance Hints
|
| 1838 |
+
* 8. Heuristics Trace
|
| 1839 |
+
* 16. API Trace
|
| 1840 |
+
*
|
| 1841 |
+
* \retval CUBLAS_STATUS_SUCCESS if log mask was set successfully
|
| 1842 |
+
*/
|
| 1843 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerSetMask(int mask);
|
| 1844 |
+
|
| 1845 |
+
/** Experimental: Disable logging for the entire session.
|
| 1846 |
+
*
|
| 1847 |
+
* \retval CUBLAS_STATUS_SUCCESS if disabled logging
|
| 1848 |
+
*/
|
| 1849 |
+
cublasStatus_t CUBLASWINAPI cublasLtLoggerForceDisable();
|
| 1850 |
+
|
| 1851 |
+
#if defined(__cplusplus)
|
| 1852 |
+
}
|
| 1853 |
+
#endif /* __cplusplus */
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_api.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/include/cublas_v2.h
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* This is the public header file for the new CUBLAS library API, it mapped the generic
|
| 52 |
+
* Cublas name functions to the actual _v2 implementations.
|
| 53 |
+
*/
|
| 54 |
+
|
| 55 |
+
#if !defined(CUBLAS_V2_H_)
|
| 56 |
+
#define CUBLAS_V2_H_
|
| 57 |
+
|
| 58 |
+
#undef CUBLASAPI
|
| 59 |
+
#ifdef __CUDACC__
|
| 60 |
+
#define CUBLASAPI __host__ __device__
|
| 61 |
+
#else
|
| 62 |
+
#define CUBLASAPI
|
| 63 |
+
#endif
|
| 64 |
+
|
| 65 |
+
#include "cublas_api.h"
|
| 66 |
+
|
| 67 |
+
#define cublasCreate cublasCreate_v2
|
| 68 |
+
#define cublasDestroy cublasDestroy_v2
|
| 69 |
+
#define cublasGetVersion cublasGetVersion_v2
|
| 70 |
+
#define cublasSetWorkspace cublasSetWorkspace_v2
|
| 71 |
+
#define cublasSetStream cublasSetStream_v2
|
| 72 |
+
#define cublasGetStream cublasGetStream_v2
|
| 73 |
+
#define cublasGetPointerMode cublasGetPointerMode_v2
|
| 74 |
+
#define cublasSetPointerMode cublasSetPointerMode_v2
|
| 75 |
+
|
| 76 |
+
/* Blas3 Routines */
|
| 77 |
+
|
| 78 |
+
#define cublasSnrm2 cublasSnrm2_v2
|
| 79 |
+
#define cublasDnrm2 cublasDnrm2_v2
|
| 80 |
+
#define cublasScnrm2 cublasScnrm2_v2
|
| 81 |
+
#define cublasDznrm2 cublasDznrm2_v2
|
| 82 |
+
|
| 83 |
+
#define cublasSdot cublasSdot_v2
|
| 84 |
+
#define cublasDdot cublasDdot_v2
|
| 85 |
+
#define cublasCdotu cublasCdotu_v2
|
| 86 |
+
#define cublasCdotc cublasCdotc_v2
|
| 87 |
+
#define cublasZdotu cublasZdotu_v2
|
| 88 |
+
#define cublasZdotc cublasZdotc_v2
|
| 89 |
+
|
| 90 |
+
#define cublasSscal cublasSscal_v2
|
| 91 |
+
#define cublasDscal cublasDscal_v2
|
| 92 |
+
#define cublasCscal cublasCscal_v2
|
| 93 |
+
#define cublasCsscal cublasCsscal_v2
|
| 94 |
+
#define cublasZscal cublasZscal_v2
|
| 95 |
+
#define cublasZdscal cublasZdscal_v2
|
| 96 |
+
|
| 97 |
+
#define cublasSaxpy cublasSaxpy_v2
|
| 98 |
+
#define cublasDaxpy cublasDaxpy_v2
|
| 99 |
+
#define cublasCaxpy cublasCaxpy_v2
|
| 100 |
+
#define cublasZaxpy cublasZaxpy_v2
|
| 101 |
+
|
| 102 |
+
#define cublasScopy cublasScopy_v2
|
| 103 |
+
#define cublasDcopy cublasDcopy_v2
|
| 104 |
+
#define cublasCcopy cublasCcopy_v2
|
| 105 |
+
#define cublasZcopy cublasZcopy_v2
|
| 106 |
+
|
| 107 |
+
#define cublasSswap cublasSswap_v2
|
| 108 |
+
#define cublasDswap cublasDswap_v2
|
| 109 |
+
#define cublasCswap cublasCswap_v2
|
| 110 |
+
#define cublasZswap cublasZswap_v2
|
| 111 |
+
|
| 112 |
+
#define cublasIsamax cublasIsamax_v2
|
| 113 |
+
#define cublasIdamax cublasIdamax_v2
|
| 114 |
+
#define cublasIcamax cublasIcamax_v2
|
| 115 |
+
#define cublasIzamax cublasIzamax_v2
|
| 116 |
+
|
| 117 |
+
#define cublasIsamin cublasIsamin_v2
|
| 118 |
+
#define cublasIdamin cublasIdamin_v2
|
| 119 |
+
#define cublasIcamin cublasIcamin_v2
|
| 120 |
+
#define cublasIzamin cublasIzamin_v2
|
| 121 |
+
|
| 122 |
+
#define cublasSasum cublasSasum_v2
|
| 123 |
+
#define cublasDasum cublasDasum_v2
|
| 124 |
+
#define cublasScasum cublasScasum_v2
|
| 125 |
+
#define cublasDzasum cublasDzasum_v2
|
| 126 |
+
|
| 127 |
+
#define cublasSrot cublasSrot_v2
|
| 128 |
+
#define cublasDrot cublasDrot_v2
|
| 129 |
+
#define cublasCrot cublasCrot_v2
|
| 130 |
+
#define cublasCsrot cublasCsrot_v2
|
| 131 |
+
#define cublasZrot cublasZrot_v2
|
| 132 |
+
#define cublasZdrot cublasZdrot_v2
|
| 133 |
+
|
| 134 |
+
#define cublasSrotg cublasSrotg_v2
|
| 135 |
+
#define cublasDrotg cublasDrotg_v2
|
| 136 |
+
#define cublasCrotg cublasCrotg_v2
|
| 137 |
+
#define cublasZrotg cublasZrotg_v2
|
| 138 |
+
|
| 139 |
+
#define cublasSrotm cublasSrotm_v2
|
| 140 |
+
#define cublasDrotm cublasDrotm_v2
|
| 141 |
+
|
| 142 |
+
#define cublasSrotmg cublasSrotmg_v2
|
| 143 |
+
#define cublasDrotmg cublasDrotmg_v2
|
| 144 |
+
|
| 145 |
+
/* Blas2 Routines */
|
| 146 |
+
|
| 147 |
+
#define cublasSgemv cublasSgemv_v2
|
| 148 |
+
#define cublasDgemv cublasDgemv_v2
|
| 149 |
+
#define cublasCgemv cublasCgemv_v2
|
| 150 |
+
#define cublasZgemv cublasZgemv_v2
|
| 151 |
+
|
| 152 |
+
#define cublasSgbmv cublasSgbmv_v2
|
| 153 |
+
#define cublasDgbmv cublasDgbmv_v2
|
| 154 |
+
#define cublasCgbmv cublasCgbmv_v2
|
| 155 |
+
#define cublasZgbmv cublasZgbmv_v2
|
| 156 |
+
|
| 157 |
+
#define cublasStrmv cublasStrmv_v2
|
| 158 |
+
#define cublasDtrmv cublasDtrmv_v2
|
| 159 |
+
#define cublasCtrmv cublasCtrmv_v2
|
| 160 |
+
#define cublasZtrmv cublasZtrmv_v2
|
| 161 |
+
|
| 162 |
+
#define cublasStbmv cublasStbmv_v2
|
| 163 |
+
#define cublasDtbmv cublasDtbmv_v2
|
| 164 |
+
#define cublasCtbmv cublasCtbmv_v2
|
| 165 |
+
#define cublasZtbmv cublasZtbmv_v2
|
| 166 |
+
|
| 167 |
+
#define cublasStpmv cublasStpmv_v2
|
| 168 |
+
#define cublasDtpmv cublasDtpmv_v2
|
| 169 |
+
#define cublasCtpmv cublasCtpmv_v2
|
| 170 |
+
#define cublasZtpmv cublasZtpmv_v2
|
| 171 |
+
|
| 172 |
+
#define cublasStrsv cublasStrsv_v2
|
| 173 |
+
#define cublasDtrsv cublasDtrsv_v2
|
| 174 |
+
#define cublasCtrsv cublasCtrsv_v2
|
| 175 |
+
#define cublasZtrsv cublasZtrsv_v2
|
| 176 |
+
|
| 177 |
+
#define cublasStpsv cublasStpsv_v2
|
| 178 |
+
#define cublasDtpsv cublasDtpsv_v2
|
| 179 |
+
#define cublasCtpsv cublasCtpsv_v2
|
| 180 |
+
#define cublasZtpsv cublasZtpsv_v2
|
| 181 |
+
|
| 182 |
+
#define cublasStbsv cublasStbsv_v2
|
| 183 |
+
#define cublasDtbsv cublasDtbsv_v2
|
| 184 |
+
#define cublasCtbsv cublasCtbsv_v2
|
| 185 |
+
#define cublasZtbsv cublasZtbsv_v2
|
| 186 |
+
|
| 187 |
+
#define cublasSsymv cublasSsymv_v2
|
| 188 |
+
#define cublasDsymv cublasDsymv_v2
|
| 189 |
+
#define cublasCsymv cublasCsymv_v2
|
| 190 |
+
#define cublasZsymv cublasZsymv_v2
|
| 191 |
+
#define cublasChemv cublasChemv_v2
|
| 192 |
+
#define cublasZhemv cublasZhemv_v2
|
| 193 |
+
|
| 194 |
+
#define cublasSsbmv cublasSsbmv_v2
|
| 195 |
+
#define cublasDsbmv cublasDsbmv_v2
|
| 196 |
+
#define cublasChbmv cublasChbmv_v2
|
| 197 |
+
#define cublasZhbmv cublasZhbmv_v2
|
| 198 |
+
|
| 199 |
+
#define cublasSspmv cublasSspmv_v2
|
| 200 |
+
#define cublasDspmv cublasDspmv_v2
|
| 201 |
+
#define cublasChpmv cublasChpmv_v2
|
| 202 |
+
#define cublasZhpmv cublasZhpmv_v2
|
| 203 |
+
|
| 204 |
+
#define cublasSger cublasSger_v2
|
| 205 |
+
#define cublasDger cublasDger_v2
|
| 206 |
+
#define cublasCgeru cublasCgeru_v2
|
| 207 |
+
#define cublasCgerc cublasCgerc_v2
|
| 208 |
+
#define cublasZgeru cublasZgeru_v2
|
| 209 |
+
#define cublasZgerc cublasZgerc_v2
|
| 210 |
+
|
| 211 |
+
#define cublasSsyr cublasSsyr_v2
|
| 212 |
+
#define cublasDsyr cublasDsyr_v2
|
| 213 |
+
#define cublasCsyr cublasCsyr_v2
|
| 214 |
+
#define cublasZsyr cublasZsyr_v2
|
| 215 |
+
#define cublasCher cublasCher_v2
|
| 216 |
+
#define cublasZher cublasZher_v2
|
| 217 |
+
|
| 218 |
+
#define cublasSspr cublasSspr_v2
|
| 219 |
+
#define cublasDspr cublasDspr_v2
|
| 220 |
+
#define cublasChpr cublasChpr_v2
|
| 221 |
+
#define cublasZhpr cublasZhpr_v2
|
| 222 |
+
|
| 223 |
+
#define cublasSsyr2 cublasSsyr2_v2
|
| 224 |
+
#define cublasDsyr2 cublasDsyr2_v2
|
| 225 |
+
#define cublasCsyr2 cublasCsyr2_v2
|
| 226 |
+
#define cublasZsyr2 cublasZsyr2_v2
|
| 227 |
+
#define cublasCher2 cublasCher2_v2
|
| 228 |
+
#define cublasZher2 cublasZher2_v2
|
| 229 |
+
|
| 230 |
+
#define cublasSspr2 cublasSspr2_v2
|
| 231 |
+
#define cublasDspr2 cublasDspr2_v2
|
| 232 |
+
#define cublasChpr2 cublasChpr2_v2
|
| 233 |
+
#define cublasZhpr2 cublasZhpr2_v2
|
| 234 |
+
|
| 235 |
+
/* Blas3 Routines */
|
| 236 |
+
|
| 237 |
+
#define cublasSgemm cublasSgemm_v2
|
| 238 |
+
#define cublasDgemm cublasDgemm_v2
|
| 239 |
+
#define cublasCgemm cublasCgemm_v2
|
| 240 |
+
#define cublasZgemm cublasZgemm_v2
|
| 241 |
+
|
| 242 |
+
#define cublasSsyrk cublasSsyrk_v2
|
| 243 |
+
#define cublasDsyrk cublasDsyrk_v2
|
| 244 |
+
#define cublasCsyrk cublasCsyrk_v2
|
| 245 |
+
#define cublasZsyrk cublasZsyrk_v2
|
| 246 |
+
#define cublasCherk cublasCherk_v2
|
| 247 |
+
#define cublasZherk cublasZherk_v2
|
| 248 |
+
|
| 249 |
+
#define cublasSsyr2k cublasSsyr2k_v2
|
| 250 |
+
#define cublasDsyr2k cublasDsyr2k_v2
|
| 251 |
+
#define cublasCsyr2k cublasCsyr2k_v2
|
| 252 |
+
#define cublasZsyr2k cublasZsyr2k_v2
|
| 253 |
+
#define cublasCher2k cublasCher2k_v2
|
| 254 |
+
#define cublasZher2k cublasZher2k_v2
|
| 255 |
+
|
| 256 |
+
#define cublasSsymm cublasSsymm_v2
|
| 257 |
+
#define cublasDsymm cublasDsymm_v2
|
| 258 |
+
#define cublasCsymm cublasCsymm_v2
|
| 259 |
+
#define cublasZsymm cublasZsymm_v2
|
| 260 |
+
#define cublasChemm cublasChemm_v2
|
| 261 |
+
#define cublasZhemm cublasZhemm_v2
|
| 262 |
+
|
| 263 |
+
#define cublasStrsm cublasStrsm_v2
|
| 264 |
+
#define cublasDtrsm cublasDtrsm_v2
|
| 265 |
+
#define cublasCtrsm cublasCtrsm_v2
|
| 266 |
+
#define cublasZtrsm cublasZtrsm_v2
|
| 267 |
+
|
| 268 |
+
#define cublasStrmm cublasStrmm_v2
|
| 269 |
+
#define cublasDtrmm cublasDtrmm_v2
|
| 270 |
+
#define cublasCtrmm cublasCtrmm_v2
|
| 271 |
+
#define cublasZtrmm cublasZtrmm_v2
|
| 272 |
+
|
| 273 |
+
#endif /* !defined(CUBLAS_V2_H_) */
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/include/cupti_pcsampling.h
ADDED
|
@@ -0,0 +1,923 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2020-2022 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
#if !defined(_CUPTI_PCSAMPLING_H_)
|
| 51 |
+
#define _CUPTI_PCSAMPLING_H_
|
| 52 |
+
|
| 53 |
+
#include <cuda.h>
|
| 54 |
+
#include <stdint.h>
|
| 55 |
+
#include <stddef.h>
|
| 56 |
+
#include "cupti_result.h"
|
| 57 |
+
|
| 58 |
+
#ifndef CUPTIAPI
|
| 59 |
+
#ifdef _WIN32
|
| 60 |
+
#define CUPTIAPI __stdcall
|
| 61 |
+
#else
|
| 62 |
+
#define CUPTIAPI
|
| 63 |
+
#endif
|
| 64 |
+
#endif
|
| 65 |
+
|
| 66 |
+
#define ACTIVITY_RECORD_ALIGNMENT 8
|
| 67 |
+
#if defined(_WIN32) // Windows 32- and 64-bit
|
| 68 |
+
#define START_PACKED_ALIGNMENT __pragma(pack(push,1)) // exact fit - no padding
|
| 69 |
+
#define PACKED_ALIGNMENT __declspec(align(ACTIVITY_RECORD_ALIGNMENT))
|
| 70 |
+
#define END_PACKED_ALIGNMENT __pragma(pack(pop))
|
| 71 |
+
#elif defined(__GNUC__) // GCC
|
| 72 |
+
#define START_PACKED_ALIGNMENT
|
| 73 |
+
#define PACKED_ALIGNMENT __attribute__ ((__packed__)) __attribute__ ((aligned (ACTIVITY_RECORD_ALIGNMENT)))
|
| 74 |
+
#define END_PACKED_ALIGNMENT
|
| 75 |
+
#else // all other compilers
|
| 76 |
+
#define START_PACKED_ALIGNMENT
|
| 77 |
+
#define PACKED_ALIGNMENT
|
| 78 |
+
#define END_PACKED_ALIGNMENT
|
| 79 |
+
#endif
|
| 80 |
+
|
| 81 |
+
#if defined(__cplusplus)
|
| 82 |
+
extern "C" {
|
| 83 |
+
#endif
|
| 84 |
+
|
| 85 |
+
#if defined(__GNUC__) && defined(CUPTI_LIB)
|
| 86 |
+
#pragma GCC visibility push(default)
|
| 87 |
+
#endif
|
| 88 |
+
|
| 89 |
+
/**
|
| 90 |
+
* \defgroup CUPTI_PCSAMPLING_API CUPTI PC Sampling API
|
| 91 |
+
* Functions, types, and enums that implement the CUPTI PC Sampling API.
|
| 92 |
+
* @{
|
| 93 |
+
*/
|
| 94 |
+
|
| 95 |
+
#ifndef CUPTI_PCSAMPLING_STRUCT_SIZE
|
| 96 |
+
#define CUPTI_PCSAMPLING_STRUCT_SIZE(type_, lastfield_) (offsetof(type_, lastfield_) + sizeof(((type_*)0)->lastfield_))
|
| 97 |
+
#endif
|
| 98 |
+
|
| 99 |
+
#ifndef CUPTI_STALL_REASON_STRING_SIZE
|
| 100 |
+
#define CUPTI_STALL_REASON_STRING_SIZE 128
|
| 101 |
+
#endif
|
| 102 |
+
|
| 103 |
+
/**
|
| 104 |
+
* \brief PC Sampling collection mode
|
| 105 |
+
*/
|
| 106 |
+
typedef enum
|
| 107 |
+
{
|
| 108 |
+
/**
|
| 109 |
+
* INVALID Value
|
| 110 |
+
*/
|
| 111 |
+
CUPTI_PC_SAMPLING_COLLECTION_MODE_INVALID = 0,
|
| 112 |
+
/**
|
| 113 |
+
* Continuous mode. Kernels are not serialized in this mode.
|
| 114 |
+
*/
|
| 115 |
+
CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS = 1,
|
| 116 |
+
/**
|
| 117 |
+
* Serialized mode. Kernels are serialized in this mode.
|
| 118 |
+
*/
|
| 119 |
+
CUPTI_PC_SAMPLING_COLLECTION_MODE_KERNEL_SERIALIZED = 2,
|
| 120 |
+
} CUpti_PCSamplingCollectionMode;
|
| 121 |
+
|
| 122 |
+
/**
|
| 123 |
+
* \brief PC Sampling stall reasons
|
| 124 |
+
*/
|
| 125 |
+
typedef struct PACKED_ALIGNMENT
|
| 126 |
+
{
|
| 127 |
+
/**
|
| 128 |
+
* [r] Collected stall reason index
|
| 129 |
+
*/
|
| 130 |
+
uint32_t pcSamplingStallReasonIndex;
|
| 131 |
+
/**
|
| 132 |
+
* [r] Number of times the PC was sampled with the stallReason.
|
| 133 |
+
*/
|
| 134 |
+
uint32_t samples;
|
| 135 |
+
} CUpti_PCSamplingStallReason;
|
| 136 |
+
|
| 137 |
+
/**
|
| 138 |
+
* \brief PC Sampling data
|
| 139 |
+
*/
|
| 140 |
+
typedef struct PACKED_ALIGNMENT
|
| 141 |
+
{
|
| 142 |
+
/**
|
| 143 |
+
* [w] Size of the data structure.
|
| 144 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 145 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 146 |
+
*/
|
| 147 |
+
size_t size;
|
| 148 |
+
/**
|
| 149 |
+
* [r] Unique cubin id
|
| 150 |
+
*/
|
| 151 |
+
uint64_t cubinCrc;
|
| 152 |
+
/**
|
| 153 |
+
* [r] PC offset
|
| 154 |
+
*/
|
| 155 |
+
uint64_t pcOffset;
|
| 156 |
+
/**
|
| 157 |
+
* The function's unique symbol index in the module.
|
| 158 |
+
*/
|
| 159 |
+
uint32_t functionIndex;
|
| 160 |
+
/**
|
| 161 |
+
* Padding
|
| 162 |
+
*/
|
| 163 |
+
uint32_t pad;
|
| 164 |
+
/**
|
| 165 |
+
* [r] The function name. This name string might be shared across all the records
|
| 166 |
+
* including records from activity APIs representing the same function, and so it should not be
|
| 167 |
+
* modified or freed until post processing of all the records is done. Once done, it is user’s responsibility to
|
| 168 |
+
* free the memory using free() function.
|
| 169 |
+
*/
|
| 170 |
+
char* functionName;
|
| 171 |
+
/**
|
| 172 |
+
* [r] Collected stall reason count
|
| 173 |
+
*/
|
| 174 |
+
size_t stallReasonCount;
|
| 175 |
+
/**
|
| 176 |
+
* [r] Stall reason id
|
| 177 |
+
* Total samples
|
| 178 |
+
*/
|
| 179 |
+
CUpti_PCSamplingStallReason *stallReason;
|
| 180 |
+
} CUpti_PCSamplingPCData;
|
| 181 |
+
|
| 182 |
+
/**
|
| 183 |
+
* \brief PC Sampling output data format
|
| 184 |
+
*/
|
| 185 |
+
typedef enum
|
| 186 |
+
{
|
| 187 |
+
CUPTI_PC_SAMPLING_OUTPUT_DATA_FORMAT_INVALID = 0,
|
| 188 |
+
/**
|
| 189 |
+
* HW buffer data will be parsed during collection of data
|
| 190 |
+
*/
|
| 191 |
+
CUPTI_PC_SAMPLING_OUTPUT_DATA_FORMAT_PARSED = 1,
|
| 192 |
+
} CUpti_PCSamplingOutputDataFormat;
|
| 193 |
+
|
| 194 |
+
/**
|
| 195 |
+
* \brief Collected PC Sampling data
|
| 196 |
+
*
|
| 197 |
+
*/
|
| 198 |
+
typedef struct PACKED_ALIGNMENT
|
| 199 |
+
{
|
| 200 |
+
/**
|
| 201 |
+
* [w] Size of the data structure.
|
| 202 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 203 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 204 |
+
*/
|
| 205 |
+
size_t size;
|
| 206 |
+
/**
|
| 207 |
+
* [w] Number of PCs to be collected
|
| 208 |
+
*/
|
| 209 |
+
size_t collectNumPcs;
|
| 210 |
+
/**
|
| 211 |
+
* [r] Number of samples collected across all PCs.
|
| 212 |
+
* It includes samples for user modules, samples for non-user kernels and dropped samples.
|
| 213 |
+
* It includes counts for all non selected stall reasons.
|
| 214 |
+
* CUPTI does not provide PC records for non-user kernels.
|
| 215 |
+
* CUPTI does not provide PC records for instructions for which all selected stall reason metrics counts are zero.
|
| 216 |
+
*/
|
| 217 |
+
uint64_t totalSamples;
|
| 218 |
+
/**
|
| 219 |
+
* [r] Number of samples that were dropped by hardware due to backpressure/overflow.
|
| 220 |
+
*/
|
| 221 |
+
uint64_t droppedSamples;
|
| 222 |
+
/**
|
| 223 |
+
* [r] Number of PCs collected
|
| 224 |
+
*/
|
| 225 |
+
size_t totalNumPcs;
|
| 226 |
+
/**
|
| 227 |
+
* [r] Number of PCs available for collection
|
| 228 |
+
*/
|
| 229 |
+
size_t remainingNumPcs;
|
| 230 |
+
/**
|
| 231 |
+
* [r] Unique identifier for each range.
|
| 232 |
+
* Data collected across multiple ranges in multiple buffers can be identified using range id.
|
| 233 |
+
*/
|
| 234 |
+
uint64_t rangeId;
|
| 235 |
+
/**
|
| 236 |
+
* [r] Profiled PC data
|
| 237 |
+
* This data struct should have enough memory to collect number of PCs mentioned in \brief collectNumPcs
|
| 238 |
+
*/
|
| 239 |
+
CUpti_PCSamplingPCData *pPcData;
|
| 240 |
+
/**
|
| 241 |
+
* [r] Number of samples collected across all non user kernels PCs.
|
| 242 |
+
* It includes samples for non-user kernels.
|
| 243 |
+
* It includes counts for all non selected stall reasons as well.
|
| 244 |
+
* CUPTI does not provide PC records for non-user kernels.
|
| 245 |
+
*/
|
| 246 |
+
uint64_t nonUsrKernelsTotalSamples;
|
| 247 |
+
} CUpti_PCSamplingData;
|
| 248 |
+
|
| 249 |
+
/**
|
| 250 |
+
* \brief PC Sampling configuration attributes
|
| 251 |
+
*
|
| 252 |
+
* PC Sampling configuration attribute types. These attributes can be read
|
| 253 |
+
* using \ref cuptiPCSamplingGetConfigurationAttribute and can be written
|
| 254 |
+
* using \ref cuptiPCSamplingSetConfigurationAttribute. Attributes marked
|
| 255 |
+
* [r] can only be read using \ref cuptiPCSamplingGetConfigurationAttribute
|
| 256 |
+
* [w] can only be written using \ref cuptiPCSamplingSetConfigurationAttribute
|
| 257 |
+
* [rw] can be read using \ref cuptiPCSamplingGetConfigurationAttribute and
|
| 258 |
+
* written using \ref cuptiPCSamplingSetConfigurationAttribute
|
| 259 |
+
*/
|
| 260 |
+
typedef enum
|
| 261 |
+
{
|
| 262 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_INVALID = 0,
|
| 263 |
+
/**
|
| 264 |
+
* [rw] Sampling period for PC Sampling.
|
| 265 |
+
* DEFAULT - CUPTI defined value based on number of SMs
|
| 266 |
+
* Valid values for the sampling
|
| 267 |
+
* periods are between 5 to 31 both inclusive. This will set the
|
| 268 |
+
* sampling period to (2^samplingPeriod) cycles.
|
| 269 |
+
* For e.g. for sampling period = 5 to 31, cycles = 32, 64, 128,..., 2^31
|
| 270 |
+
* Value is a uint32_t
|
| 271 |
+
*/
|
| 272 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_PERIOD = 1,
|
| 273 |
+
/**
|
| 274 |
+
* [w] Number of stall reasons to collect.
|
| 275 |
+
* DEFAULT - All stall reasons will be collected
|
| 276 |
+
* Value is a size_t
|
| 277 |
+
* [w] Stall reasons to collect
|
| 278 |
+
* DEFAULT - All stall reasons will be collected
|
| 279 |
+
* Input value should be a pointer pointing to array of stall reason indexes
|
| 280 |
+
* containing all the stall reason indexes to collect.
|
| 281 |
+
*/
|
| 282 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON = 2,
|
| 283 |
+
/**
|
| 284 |
+
* [rw] Size of SW buffer for raw PC counter data downloaded from HW buffer
|
| 285 |
+
* DEFAULT - 1 MB, which can accommodate approximately 5500 PCs
|
| 286 |
+
* with all stall reasons
|
| 287 |
+
* Approximately it takes 16 Bytes (and some fixed size memory)
|
| 288 |
+
* to accommodate one PC with one stall reason
|
| 289 |
+
* For e.g. 1 PC with 1 stall reason = 32 Bytes
|
| 290 |
+
* 1 PC with 2 stall reason = 48 Bytes
|
| 291 |
+
* 1 PC with 4 stall reason = 96 Bytes
|
| 292 |
+
* Value is a size_t
|
| 293 |
+
*/
|
| 294 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SCRATCH_BUFFER_SIZE = 3,
|
| 295 |
+
/**
|
| 296 |
+
* [rw] Size of HW buffer in bytes
|
| 297 |
+
* DEFAULT - 512 MB
|
| 298 |
+
* If sampling period is too less, HW buffer can overflow
|
| 299 |
+
* and drop PC data
|
| 300 |
+
* Value is a size_t
|
| 301 |
+
*/
|
| 302 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_HARDWARE_BUFFER_SIZE = 4,
|
| 303 |
+
/**
|
| 304 |
+
* [rw] PC Sampling collection mode
|
| 305 |
+
* DEFAULT - CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS
|
| 306 |
+
* Input value should be of type \ref CUpti_PCSamplingCollectionMode.
|
| 307 |
+
*/
|
| 308 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_COLLECTION_MODE = 5,
|
| 309 |
+
/**
|
| 310 |
+
* [rw] Control over PC Sampling data collection range
|
| 311 |
+
* Default - 0
|
| 312 |
+
* 1 - Allows user to start and stop PC Sampling using APIs -
|
| 313 |
+
* \ref cuptiPCSamplingStart() - Start PC Sampling
|
| 314 |
+
* \ref cuptiPCSamplingStop() - Stop PC Sampling
|
| 315 |
+
* Value is a uint32_t
|
| 316 |
+
*/
|
| 317 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL = 6,
|
| 318 |
+
/**
|
| 319 |
+
* [w] Value for output data format
|
| 320 |
+
* Default - CUPTI_PC_SAMPLING_OUTPUT_DATA_FORMAT_PARSED
|
| 321 |
+
* Input value should be of type \ref CUpti_PCSamplingOutputDataFormat.
|
| 322 |
+
*/
|
| 323 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_OUTPUT_DATA_FORMAT = 7,
|
| 324 |
+
/**
|
| 325 |
+
* [w] Data buffer to hold collected PC Sampling data PARSED_DATA
|
| 326 |
+
* Default - none.
|
| 327 |
+
* Buffer type is void * which can point to PARSED_DATA
|
| 328 |
+
* Refer \ref CUpti_PCSamplingData for buffer format for PARSED_DATA
|
| 329 |
+
*/
|
| 330 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_DATA_BUFFER = 8,
|
| 331 |
+
CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_FORCE_INT = 0x7fffffff,
|
| 332 |
+
} CUpti_PCSamplingConfigurationAttributeType;
|
| 333 |
+
|
| 334 |
+
/**
|
| 335 |
+
* \brief PC sampling configuration information structure
|
| 336 |
+
*
|
| 337 |
+
* This structure provides \ref CUpti_PCSamplingConfigurationAttributeType which can be configured
|
| 338 |
+
* or queried for PC sampling configuration
|
| 339 |
+
*/
|
| 340 |
+
typedef struct
|
| 341 |
+
{
|
| 342 |
+
/**
|
| 343 |
+
* Refer \ref CUpti_PCSamplingConfigurationAttributeType for all supported attribute types
|
| 344 |
+
*/
|
| 345 |
+
CUpti_PCSamplingConfigurationAttributeType attributeType;
|
| 346 |
+
/*
|
| 347 |
+
* Configure or query status for \p attributeType
|
| 348 |
+
* CUPTI_SUCCESS for valid \p attributeType and \p attributeData
|
| 349 |
+
* CUPTI_ERROR_INVALID_OPERATION if \p attributeData is not valid
|
| 350 |
+
* CUPTI_ERROR_INVALID_PARAMETER if \p attributeType is not valid
|
| 351 |
+
*/
|
| 352 |
+
CUptiResult attributeStatus;
|
| 353 |
+
union
|
| 354 |
+
{
|
| 355 |
+
/**
|
| 356 |
+
* Invalid Value
|
| 357 |
+
*/
|
| 358 |
+
struct
|
| 359 |
+
{
|
| 360 |
+
uint64_t data[3];
|
| 361 |
+
} invalidData;
|
| 362 |
+
/**
|
| 363 |
+
* Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_PERIOD
|
| 364 |
+
*/
|
| 365 |
+
struct
|
| 366 |
+
{
|
| 367 |
+
uint32_t samplingPeriod;
|
| 368 |
+
} samplingPeriodData;
|
| 369 |
+
/**
|
| 370 |
+
* Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON
|
| 371 |
+
*/
|
| 372 |
+
struct
|
| 373 |
+
{
|
| 374 |
+
size_t stallReasonCount;
|
| 375 |
+
uint32_t *pStallReasonIndex;
|
| 376 |
+
} stallReasonData;
|
| 377 |
+
/**
|
| 378 |
+
* Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SCRATCH_BUFFER_SIZE
|
| 379 |
+
*/
|
| 380 |
+
struct
|
| 381 |
+
{
|
| 382 |
+
size_t scratchBufferSize;
|
| 383 |
+
} scratchBufferSizeData;
|
| 384 |
+
/**
|
| 385 |
+
* Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_HARDWARE_BUFFER_SIZE
|
| 386 |
+
*/
|
| 387 |
+
struct
|
| 388 |
+
{
|
| 389 |
+
size_t hardwareBufferSize;
|
| 390 |
+
} hardwareBufferSizeData;
|
| 391 |
+
/**
|
| 392 |
+
* Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_COLLECTION_MODE
|
| 393 |
+
*/
|
| 394 |
+
struct
|
| 395 |
+
{
|
| 396 |
+
CUpti_PCSamplingCollectionMode collectionMode;
|
| 397 |
+
} collectionModeData;
|
| 398 |
+
/**
|
| 399 |
+
* Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL
|
| 400 |
+
*/
|
| 401 |
+
struct
|
| 402 |
+
{
|
| 403 |
+
uint32_t enableStartStopControl;
|
| 404 |
+
} enableStartStopControlData;
|
| 405 |
+
/**
|
| 406 |
+
* Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_OUTPUT_DATA_FORMAT
|
| 407 |
+
*/
|
| 408 |
+
struct
|
| 409 |
+
{
|
| 410 |
+
CUpti_PCSamplingOutputDataFormat outputDataFormat;
|
| 411 |
+
} outputDataFormatData;
|
| 412 |
+
/**
|
| 413 |
+
* Refer \ref CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_DATA_BUFFER
|
| 414 |
+
*/
|
| 415 |
+
struct
|
| 416 |
+
{
|
| 417 |
+
void *samplingDataBuffer;
|
| 418 |
+
} samplingDataBufferData;
|
| 419 |
+
} attributeData;
|
| 420 |
+
} CUpti_PCSamplingConfigurationInfo;
|
| 421 |
+
|
| 422 |
+
/**
|
| 423 |
+
* \brief PC sampling configuration structure
|
| 424 |
+
*
|
| 425 |
+
* This structure configures PC sampling using \ref cuptiPCSamplingSetConfigurationAttribute
|
| 426 |
+
* and queries PC sampling default configuration using \ref cuptiPCSamplingGetConfigurationAttribute
|
| 427 |
+
*/
|
| 428 |
+
typedef struct
|
| 429 |
+
{
|
| 430 |
+
/**
|
| 431 |
+
* [w] Size of the data structure i.e. CUpti_PCSamplingConfigurationInfoParamsSize
|
| 432 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 433 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 434 |
+
*/
|
| 435 |
+
size_t size;
|
| 436 |
+
/**
|
| 437 |
+
* [w] Assign to NULL
|
| 438 |
+
*/
|
| 439 |
+
void* pPriv;
|
| 440 |
+
/**
|
| 441 |
+
* [w] CUcontext
|
| 442 |
+
*/
|
| 443 |
+
CUcontext ctx;
|
| 444 |
+
/**
|
| 445 |
+
* [w] Number of attributes to configure using \ref cuptiPCSamplingSetConfigurationAttribute or query
|
| 446 |
+
* using \ref cuptiPCSamplingGetConfigurationAttribute
|
| 447 |
+
*/
|
| 448 |
+
size_t numAttributes;
|
| 449 |
+
/**
|
| 450 |
+
* Refer \ref CUpti_PCSamplingConfigurationInfo
|
| 451 |
+
*/
|
| 452 |
+
CUpti_PCSamplingConfigurationInfo *pPCSamplingConfigurationInfo;
|
| 453 |
+
} CUpti_PCSamplingConfigurationInfoParams;
|
| 454 |
+
#define CUpti_PCSamplingConfigurationInfoParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingConfigurationInfoParams,pPCSamplingConfigurationInfo)
|
| 455 |
+
|
| 456 |
+
/**
|
| 457 |
+
* \brief Write PC Sampling configuration attribute.
|
| 458 |
+
*
|
| 459 |
+
* \param pParams A pointer to \ref CUpti_PCSamplingConfigurationInfoParams
|
| 460 |
+
* containing PC sampling configuration.
|
| 461 |
+
*
|
| 462 |
+
* \retval CUPTI_SUCCESS
|
| 463 |
+
* \retval CUPTI_ERROR_INVALID_OPERATION if this API is called with
|
| 464 |
+
* some invalid \p attrib.
|
| 465 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if attribute \p value is not valid
|
| 466 |
+
* or any \p pParams is not valid
|
| 467 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 468 |
+
* does not support the API
|
| 469 |
+
*/
|
| 470 |
+
CUptiResult CUPTIAPI cuptiPCSamplingSetConfigurationAttribute(CUpti_PCSamplingConfigurationInfoParams *pParams);
|
| 471 |
+
|
| 472 |
+
/**
|
| 473 |
+
* \brief Read PC Sampling configuration attribute.
|
| 474 |
+
*
|
| 475 |
+
* \param pParams A pointer to \ref CUpti_PCSamplingConfigurationInfoParams
|
| 476 |
+
* containing PC sampling configuration.
|
| 477 |
+
*
|
| 478 |
+
* \retval CUPTI_SUCCESS
|
| 479 |
+
* \retval CUPTI_ERROR_INVALID_OPERATION if this API is called with
|
| 480 |
+
* some invalid attribute.
|
| 481 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if \p attrib is not valid
|
| 482 |
+
* or any \p pParams is not valid
|
| 483 |
+
* \retval CUPTI_ERROR_PARAMETER_SIZE_NOT_SUFFICIENT indicates that
|
| 484 |
+
* the \p value buffer is too small to hold the attribute value
|
| 485 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 486 |
+
* does not support the API
|
| 487 |
+
*/
|
| 488 |
+
CUptiResult CUPTIAPI cuptiPCSamplingGetConfigurationAttribute(CUpti_PCSamplingConfigurationInfoParams *pParams);
|
| 489 |
+
|
| 490 |
+
/**
|
| 491 |
+
* \brief Params for cuptiPCSamplingEnable
|
| 492 |
+
*/
|
| 493 |
+
typedef struct
|
| 494 |
+
{
|
| 495 |
+
/**
|
| 496 |
+
* [w] Size of the data structure i.e. CUpti_PCSamplingGetDataParamsSize
|
| 497 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 498 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 499 |
+
*/
|
| 500 |
+
size_t size;
|
| 501 |
+
/**
|
| 502 |
+
* [w] Assign to NULL
|
| 503 |
+
*/
|
| 504 |
+
void* pPriv;
|
| 505 |
+
/**
|
| 506 |
+
* [w] CUcontext
|
| 507 |
+
*/
|
| 508 |
+
CUcontext ctx;
|
| 509 |
+
/**
|
| 510 |
+
* \param pcSamplingData Data buffer to hold collected PC Sampling data PARSED_DATA
|
| 511 |
+
* Buffer type is void * which can point to PARSED_DATA
|
| 512 |
+
* Refer \ref CUpti_PCSamplingData for buffer format for PARSED_DATA
|
| 513 |
+
*/
|
| 514 |
+
void *pcSamplingData;
|
| 515 |
+
} CUpti_PCSamplingGetDataParams;
|
| 516 |
+
#define CUpti_PCSamplingGetDataParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingGetDataParams, pcSamplingData)
|
| 517 |
+
/**
|
| 518 |
+
* \brief Flush GPU PC sampling data periodically.
|
| 519 |
+
*
|
| 520 |
+
* Flushing of GPU PC Sampling data is required at following point to maintain uniqueness of PCs:
|
| 521 |
+
* For \brief CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS, after every module load-unload-load
|
| 522 |
+
* For \brief CUPTI_PC_SAMPLING_COLLECTION_MODE_KERNEL_SERIALIZED, after every kernel ends
|
| 523 |
+
* If configuration option \brief CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL
|
| 524 |
+
* is enabled, then after every range end i.e. \brief cuptiPCSamplingStop()
|
| 525 |
+
*
|
| 526 |
+
* If application is profiled in \brief CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS, with disabled
|
| 527 |
+
* \brief CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL, and there is no module unload,
|
| 528 |
+
* user can collect data in two ways:
|
| 529 |
+
* Use \brief cuptiPCSamplingGetData() API periodically
|
| 530 |
+
* Use \brief cuptiPCSamplingDisable() on application exit and read GPU PC sampling data from sampling
|
| 531 |
+
* data buffer passed during configuration.
|
| 532 |
+
* Note: In case, \brief cuptiPCSamplingGetData() API is not called periodically, then sampling data buffer
|
| 533 |
+
* passed during configuration should be large enough to hold all PCs data.
|
| 534 |
+
* \brief cuptiPCSamplingGetData() API never does device synchronization.
|
| 535 |
+
* It is possible that when the API is called there is some unconsumed data from the HW buffer. In this case
|
| 536 |
+
* CUPTI provides only the data available with it at that moment.
|
| 537 |
+
*
|
| 538 |
+
* \param Refer \ref CUpti_PCSamplingGetDataParams
|
| 539 |
+
*
|
| 540 |
+
* \retval CUPTI_SUCCESS
|
| 541 |
+
* \retval CUPTI_ERROR_INVALID_OPERATION if this API is called without
|
| 542 |
+
* enabling PC sampling.
|
| 543 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
|
| 544 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 545 |
+
* does not support the API
|
| 546 |
+
*/
|
| 547 |
+
CUptiResult CUPTIAPI cuptiPCSamplingGetData(CUpti_PCSamplingGetDataParams *pParams);
|
| 548 |
+
|
| 549 |
+
/**
|
| 550 |
+
* \brief Params for cuptiPCSamplingEnable
|
| 551 |
+
*/
|
| 552 |
+
typedef struct
|
| 553 |
+
{
|
| 554 |
+
/**
|
| 555 |
+
* [w] Size of the data structure i.e. CUpti_PCSamplingEnableParamsSize
|
| 556 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 557 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 558 |
+
*/
|
| 559 |
+
size_t size;
|
| 560 |
+
/**
|
| 561 |
+
* [w] Assign to NULL
|
| 562 |
+
*/
|
| 563 |
+
void* pPriv;
|
| 564 |
+
/**
|
| 565 |
+
* [w] CUcontext
|
| 566 |
+
*/
|
| 567 |
+
CUcontext ctx;
|
| 568 |
+
} CUpti_PCSamplingEnableParams;
|
| 569 |
+
#define CUpti_PCSamplingEnableParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingEnableParams, ctx)
|
| 570 |
+
|
| 571 |
+
/**
|
| 572 |
+
* \brief Enable PC sampling.
|
| 573 |
+
*
|
| 574 |
+
* \param Refer \ref CUpti_PCSamplingEnableParams
|
| 575 |
+
*
|
| 576 |
+
* \retval CUPTI_SUCCESS
|
| 577 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
|
| 578 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 579 |
+
* does not support the API
|
| 580 |
+
*/
|
| 581 |
+
CUptiResult CUPTIAPI cuptiPCSamplingEnable(CUpti_PCSamplingEnableParams *pParams);
|
| 582 |
+
|
| 583 |
+
/**
|
| 584 |
+
* \brief Params for cuptiPCSamplingDisable
|
| 585 |
+
*/
|
| 586 |
+
typedef struct
|
| 587 |
+
{
|
| 588 |
+
/**
|
| 589 |
+
* [w] Size of the data structure i.e. CUpti_PCSamplingDisableParamsSize
|
| 590 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 591 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 592 |
+
*/
|
| 593 |
+
size_t size;
|
| 594 |
+
/**
|
| 595 |
+
* [w] Assign to NULL
|
| 596 |
+
*/
|
| 597 |
+
void* pPriv;
|
| 598 |
+
/**
|
| 599 |
+
* [w] CUcontext
|
| 600 |
+
*/
|
| 601 |
+
CUcontext ctx;
|
| 602 |
+
} CUpti_PCSamplingDisableParams;
|
| 603 |
+
#define CUpti_PCSamplingDisableParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingDisableParams, ctx)
|
| 604 |
+
|
| 605 |
+
/**
|
| 606 |
+
* \brief Disable PC sampling.
|
| 607 |
+
*
|
| 608 |
+
* For application which doesn't destroy the CUDA context explicitly,
|
| 609 |
+
* this API does the PC Sampling tear-down, joins threads and copies PC records in the buffer provided
|
| 610 |
+
* during the PC sampling configuration. PC records which can't be accommodated in the buffer are discarded.
|
| 611 |
+
*
|
| 612 |
+
* \param Refer \ref CUpti_PCSamplingDisableParams
|
| 613 |
+
*
|
| 614 |
+
* \retval CUPTI_SUCCESS
|
| 615 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
|
| 616 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 617 |
+
* does not support the API
|
| 618 |
+
*/
|
| 619 |
+
CUptiResult CUPTIAPI cuptiPCSamplingDisable(CUpti_PCSamplingDisableParams *pParams);
|
| 620 |
+
|
| 621 |
+
/**
|
| 622 |
+
* \brief Params for cuptiPCSamplingStart
|
| 623 |
+
*/
|
| 624 |
+
typedef struct
|
| 625 |
+
{
|
| 626 |
+
/**
|
| 627 |
+
* [w] Size of the data structure i.e. CUpti_PCSamplingStartParamsSize
|
| 628 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 629 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 630 |
+
*/
|
| 631 |
+
size_t size;
|
| 632 |
+
/**
|
| 633 |
+
* [w] Assign to NULL
|
| 634 |
+
*/
|
| 635 |
+
void* pPriv;
|
| 636 |
+
/**
|
| 637 |
+
* [w] CUcontext
|
| 638 |
+
*/
|
| 639 |
+
CUcontext ctx;
|
| 640 |
+
} CUpti_PCSamplingStartParams;
|
| 641 |
+
#define CUpti_PCSamplingStartParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingStartParams, ctx)
|
| 642 |
+
|
| 643 |
+
/**
|
| 644 |
+
* \brief Start PC sampling.
|
| 645 |
+
*
|
| 646 |
+
* User can collect PC Sampling data for user-defined range specified by Start/Stop APIs.
|
| 647 |
+
* This API can be used to mark starting of range. Set configuration option
|
| 648 |
+
* \brief CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL to use this API.
|
| 649 |
+
*
|
| 650 |
+
* \param Refer \ref CUpti_PCSamplingStartParams
|
| 651 |
+
*
|
| 652 |
+
* \retval CUPTI_SUCCESS
|
| 653 |
+
* \retval CUPTI_ERROR_INVALID_OPERATION if this API is called with
|
| 654 |
+
* incorrect PC Sampling configuration.
|
| 655 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
|
| 656 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 657 |
+
* does not support the API
|
| 658 |
+
*/
|
| 659 |
+
CUptiResult CUPTIAPI cuptiPCSamplingStart(CUpti_PCSamplingStartParams *pParams);
|
| 660 |
+
|
| 661 |
+
/**
|
| 662 |
+
* \brief Params for cuptiPCSamplingStop
|
| 663 |
+
*/
|
| 664 |
+
typedef struct
|
| 665 |
+
{
|
| 666 |
+
/**
|
| 667 |
+
* [w] Size of the data structure i.e. CUpti_PCSamplingStopParamsSize
|
| 668 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 669 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 670 |
+
*/
|
| 671 |
+
size_t size;
|
| 672 |
+
/**
|
| 673 |
+
* [w] Assign to NULL
|
| 674 |
+
*/
|
| 675 |
+
void* pPriv;
|
| 676 |
+
/**
|
| 677 |
+
* [w] CUcontext
|
| 678 |
+
*/
|
| 679 |
+
CUcontext ctx;
|
| 680 |
+
} CUpti_PCSamplingStopParams;
|
| 681 |
+
#define CUpti_PCSamplingStopParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingStopParams, ctx)
|
| 682 |
+
|
| 683 |
+
/**
|
| 684 |
+
* \brief Stop PC sampling.
|
| 685 |
+
*
|
| 686 |
+
* User can collect PC Sampling data for user-defined range specified by Start/Stop APIs.
|
| 687 |
+
* This API can be used to mark end of range. Set configuration option
|
| 688 |
+
* \brief CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL to use this API.
|
| 689 |
+
*
|
| 690 |
+
* \param Refer \ref CUpti_PCSamplingStopParams
|
| 691 |
+
*
|
| 692 |
+
* \retval CUPTI_SUCCESS
|
| 693 |
+
* \retval CUPTI_ERROR_INVALID_OPERATION if this API is called with
|
| 694 |
+
* incorrect PC Sampling configuration.
|
| 695 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
|
| 696 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 697 |
+
* does not support the API
|
| 698 |
+
*/
|
| 699 |
+
CUptiResult CUPTIAPI cuptiPCSamplingStop(CUpti_PCSamplingStopParams *pParams);
|
| 700 |
+
|
| 701 |
+
/**
|
| 702 |
+
* \brief Params for cuptiPCSamplingGetNumStallReasons
|
| 703 |
+
*/
|
| 704 |
+
typedef struct
|
| 705 |
+
{
|
| 706 |
+
/**
|
| 707 |
+
* [w] Size of the data structure i.e. CUpti_PCSamplingGetNumStallReasonsParamsSize
|
| 708 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 709 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 710 |
+
*/
|
| 711 |
+
size_t size;
|
| 712 |
+
/**
|
| 713 |
+
* [w] Assign to NULL
|
| 714 |
+
*/
|
| 715 |
+
void* pPriv;
|
| 716 |
+
/**
|
| 717 |
+
* [w] CUcontext
|
| 718 |
+
*/
|
| 719 |
+
CUcontext ctx;
|
| 720 |
+
/**
|
| 721 |
+
* [r] Number of stall reasons
|
| 722 |
+
*/
|
| 723 |
+
size_t *numStallReasons;
|
| 724 |
+
} CUpti_PCSamplingGetNumStallReasonsParams;
|
| 725 |
+
#define CUpti_PCSamplingGetNumStallReasonsParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingGetNumStallReasonsParams, numStallReasons)
|
| 726 |
+
|
| 727 |
+
/**
|
| 728 |
+
* \brief Get PC sampling stall reason count.
|
| 729 |
+
*
|
| 730 |
+
* \param Refer \ref CUpti_PCSamplingGetNumStallReasonsParams
|
| 731 |
+
*
|
| 732 |
+
* \retval CUPTI_SUCCESS
|
| 733 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
|
| 734 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 735 |
+
* does not support the API
|
| 736 |
+
*/
|
| 737 |
+
CUptiResult CUPTIAPI cuptiPCSamplingGetNumStallReasons(CUpti_PCSamplingGetNumStallReasonsParams *pParams);
|
| 738 |
+
|
| 739 |
+
/**
|
| 740 |
+
* \brief Params for cuptiPCSamplingGetStallReasons
|
| 741 |
+
*/
|
| 742 |
+
typedef struct
|
| 743 |
+
{
|
| 744 |
+
/**
|
| 745 |
+
* [w] Size of the data structure i.e. CUpti_PCSamplingGetStallReasonsParamsSize
|
| 746 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 747 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 748 |
+
*/
|
| 749 |
+
size_t size;
|
| 750 |
+
/**
|
| 751 |
+
* [w] Assign to NULL
|
| 752 |
+
*/
|
| 753 |
+
void* pPriv;
|
| 754 |
+
/**
|
| 755 |
+
* [w] CUcontext
|
| 756 |
+
*/
|
| 757 |
+
CUcontext ctx;
|
| 758 |
+
/**
|
| 759 |
+
* [w] Number of stall reasons
|
| 760 |
+
*/
|
| 761 |
+
size_t numStallReasons;
|
| 762 |
+
/**
|
| 763 |
+
* [r] Stall reason index
|
| 764 |
+
*/
|
| 765 |
+
uint32_t *stallReasonIndex;
|
| 766 |
+
/**
|
| 767 |
+
* [r] Stall reasons name
|
| 768 |
+
*/
|
| 769 |
+
char **stallReasons;
|
| 770 |
+
} CUpti_PCSamplingGetStallReasonsParams;
|
| 771 |
+
#define CUpti_PCSamplingGetStallReasonsParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_PCSamplingGetStallReasonsParams, stallReasons)
|
| 772 |
+
|
| 773 |
+
/**
|
| 774 |
+
* \brief Get PC sampling stall reasons.
|
| 775 |
+
*
|
| 776 |
+
* \param Refer \ref CUpti_PCSamplingGetStallReasonsParams
|
| 777 |
+
*
|
| 778 |
+
* \retval CUPTI_SUCCESS
|
| 779 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if any \p pParams is not valid
|
| 780 |
+
* \retval CUPTI_ERROR_NOT_SUPPORTED indicates that the system/device
|
| 781 |
+
* does not support the API
|
| 782 |
+
*/
|
| 783 |
+
CUptiResult CUPTIAPI cuptiPCSamplingGetStallReasons(CUpti_PCSamplingGetStallReasonsParams *pParams);
|
| 784 |
+
|
| 785 |
+
/**
|
| 786 |
+
* \brief Params for cuptiGetSassToSourceCorrelation
|
| 787 |
+
*/
|
| 788 |
+
typedef struct {
|
| 789 |
+
/**
|
| 790 |
+
* [w] Size of the data structure i.e. CUpti_GetSassToSourceCorrelationParamsSize
|
| 791 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 792 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 793 |
+
*/
|
| 794 |
+
size_t size;
|
| 795 |
+
/**
|
| 796 |
+
* [w] Pointer to cubin binary where function belongs.
|
| 797 |
+
*/
|
| 798 |
+
const void* cubin;
|
| 799 |
+
/**
|
| 800 |
+
* [w] Function name to which PC belongs.
|
| 801 |
+
*/
|
| 802 |
+
const char *functionName;
|
| 803 |
+
/**
|
| 804 |
+
* [w] Size of cubin binary.
|
| 805 |
+
*/
|
| 806 |
+
size_t cubinSize;
|
| 807 |
+
/**
|
| 808 |
+
* [r] Line number in the source code.
|
| 809 |
+
*/
|
| 810 |
+
uint32_t lineNumber;
|
| 811 |
+
/**
|
| 812 |
+
* [w] PC offset
|
| 813 |
+
*/
|
| 814 |
+
uint64_t pcOffset;
|
| 815 |
+
/**
|
| 816 |
+
* [r] Path for the source file.
|
| 817 |
+
*/
|
| 818 |
+
char *fileName;
|
| 819 |
+
/**
|
| 820 |
+
* [r] Path for the directory of source file.
|
| 821 |
+
*/
|
| 822 |
+
char *dirName;
|
| 823 |
+
} CUpti_GetSassToSourceCorrelationParams;
|
| 824 |
+
#define CUpti_GetSassToSourceCorrelationParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_GetSassToSourceCorrelationParams, dirName)
|
| 825 |
+
|
| 826 |
+
/**
|
| 827 |
+
* \brief SASS to Source correlation.
|
| 828 |
+
*
|
| 829 |
+
* \param Refer \ref CUpti_GetSassToSourceCorrelationParams
|
| 830 |
+
*
|
| 831 |
+
* It is expected from user to free allocated memory for fileName and dirName after use.
|
| 832 |
+
*
|
| 833 |
+
* \retval CUPTI_SUCCESS
|
| 834 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if either of the parameters cubin or functionName
|
| 835 |
+
* is NULL or cubinSize is zero or size field is not set correctly.
|
| 836 |
+
* \retval CUPTI_ERROR_INVALID_MODULE provided cubin is invalid.
|
| 837 |
+
* \retval CUPTI_ERROR_UNKNOWN an internal error occurred.
|
| 838 |
+
* This error code is also used for cases when the function is not present in the module.
|
| 839 |
+
* A better error code will be returned in the future release.
|
| 840 |
+
*/
|
| 841 |
+
CUptiResult CUPTIAPI cuptiGetSassToSourceCorrelation(CUpti_GetSassToSourceCorrelationParams *pParams);
|
| 842 |
+
|
| 843 |
+
/**
|
| 844 |
+
* \brief Params for cuptiGetCubinCrc
|
| 845 |
+
*/
|
| 846 |
+
typedef struct {
|
| 847 |
+
/**
|
| 848 |
+
* [w] Size of configuration structure.
|
| 849 |
+
* CUPTI client should set the size of the structure. It will be used in CUPTI to check what fields are
|
| 850 |
+
* available in the structure. Used to preserve backward compatibility.
|
| 851 |
+
*/
|
| 852 |
+
size_t size;
|
| 853 |
+
/**
|
| 854 |
+
* [w] Size of cubin binary.
|
| 855 |
+
*/
|
| 856 |
+
size_t cubinSize;
|
| 857 |
+
/**
|
| 858 |
+
* [w] Pointer to cubin binary
|
| 859 |
+
*/
|
| 860 |
+
const void* cubin;
|
| 861 |
+
/**
|
| 862 |
+
* [r] Computed CRC will be stored in it.
|
| 863 |
+
*/
|
| 864 |
+
uint64_t cubinCrc;
|
| 865 |
+
} CUpti_GetCubinCrcParams;
|
| 866 |
+
#define CUpti_GetCubinCrcParamsSize CUPTI_PCSAMPLING_STRUCT_SIZE(CUpti_GetCubinCrcParams, cubinCrc)
|
| 867 |
+
|
| 868 |
+
/**
|
| 869 |
+
* \brief Get the CRC of cubin.
|
| 870 |
+
*
|
| 871 |
+
* This function returns the CRC of provided cubin binary.
|
| 872 |
+
*
|
| 873 |
+
* \param Refer \ref CUpti_GetCubinCrcParams
|
| 874 |
+
*
|
| 875 |
+
* \retval CUPTI_SUCCESS
|
| 876 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if parameter cubin is NULL or
|
| 877 |
+
* provided cubinSize is zero or size field is not set.
|
| 878 |
+
*/
|
| 879 |
+
CUptiResult CUPTIAPI cuptiGetCubinCrc(CUpti_GetCubinCrcParams *pParams);
|
| 880 |
+
|
| 881 |
+
/**
|
| 882 |
+
* \brief Function type for callback used by CUPTI to request crc of
|
| 883 |
+
* loaded module.
|
| 884 |
+
*
|
| 885 |
+
* This callback function ask for crc of provided module in function.
|
| 886 |
+
* The provided crc will be stored in PC sampling records i.e. in the field 'cubinCrc' of the PC sampling
|
| 887 |
+
* struct CUpti_PCSamplingPCData. The CRC is uses during the offline source correlation to uniquely identify the module.
|
| 888 |
+
*
|
| 889 |
+
* \param cubin The pointer to cubin binary
|
| 890 |
+
* \param cubinSize The size of cubin binary.
|
| 891 |
+
* \param cubinCrc Returns the computed crc of cubin.
|
| 892 |
+
*/
|
| 893 |
+
typedef void (CUPTIAPI *CUpti_ComputeCrcCallbackFunc)(
|
| 894 |
+
const void* cubin,
|
| 895 |
+
size_t cubinSize,
|
| 896 |
+
uint64_t *cubinCrc);
|
| 897 |
+
|
| 898 |
+
/**
|
| 899 |
+
* \brief Register callback function with CUPTI to use
|
| 900 |
+
* your own algorithm to compute cubin crc.
|
| 901 |
+
*
|
| 902 |
+
* This function registers a callback function and it gets called
|
| 903 |
+
* from CUPTI when a CUDA module is loaded.
|
| 904 |
+
*
|
| 905 |
+
* \param funcComputeCubinCrc callback is invoked when a CUDA module
|
| 906 |
+
* is loaded.
|
| 907 |
+
*
|
| 908 |
+
* \retval CUPTI_SUCCESS
|
| 909 |
+
* \retval CUPTI_ERROR_INVALID_PARAMETER if \p funcComputeCubinCrc is NULL.
|
| 910 |
+
*/
|
| 911 |
+
CUptiResult CUPTIAPI cuptiRegisterComputeCrcCallback(CUpti_ComputeCrcCallbackFunc funcComputeCubinCrc);
|
| 912 |
+
|
| 913 |
+
/** @} */ /* END CUPTI_PCSAMPLING_API */
|
| 914 |
+
|
| 915 |
+
#if defined(__GNUC__) && defined(CUPTI_LIB)
|
| 916 |
+
#pragma GCC visibility pop
|
| 917 |
+
#endif
|
| 918 |
+
|
| 919 |
+
#if defined(__cplusplus)
|
| 920 |
+
}
|
| 921 |
+
#endif
|
| 922 |
+
|
| 923 |
+
#endif /*_CUPTI_PCSAMPLING_H_*/
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (220 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_cnn_infer_v8.h
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*
|
| 51 |
+
* cudnn_cnn_infer : cuDNN's basic definitions and inference CNN functions.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#if !defined(CUDNN_CNN_INFER_H_)
|
| 55 |
+
#define CUDNN_CNN_INFER_H_
|
| 56 |
+
|
| 57 |
+
#pragma once
|
| 58 |
+
#include <cuda_runtime.h>
|
| 59 |
+
#include <stdint.h>
|
| 60 |
+
|
| 61 |
+
#include "cudnn_version.h"
|
| 62 |
+
#include "cudnn_ops_infer.h"
|
| 63 |
+
|
| 64 |
+
/* These version numbers are autogenerated, do not edit manually. */
|
| 65 |
+
#define CUDNN_CNN_INFER_MAJOR 8
|
| 66 |
+
#define CUDNN_CNN_INFER_MINOR 7
|
| 67 |
+
#define CUDNN_CNN_INFER_PATCH 0
|
| 68 |
+
|
| 69 |
+
#if (CUDNN_CNN_INFER_MAJOR != CUDNN_MAJOR) || (CUDNN_CNN_INFER_MINOR != CUDNN_MINOR) || \
|
| 70 |
+
(CUDNN_CNN_INFER_PATCH != CUDNN_PATCHLEVEL)
|
| 71 |
+
#error Version mismatch in cuDNN CNN INFER!!!
|
| 72 |
+
#endif
|
| 73 |
+
|
| 74 |
+
#if defined(__cplusplus)
|
| 75 |
+
extern "C" {
|
| 76 |
+
#endif
|
| 77 |
+
|
| 78 |
+
typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t;
|
| 79 |
+
|
| 80 |
+
/*
|
| 81 |
+
* convolution mode
|
| 82 |
+
*/
|
| 83 |
+
typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t;
|
| 84 |
+
|
| 85 |
+
/*
|
| 86 |
+
* CUDNN Reorder
|
| 87 |
+
*/
|
| 88 |
+
typedef enum {
|
| 89 |
+
CUDNN_DEFAULT_REORDER = 0,
|
| 90 |
+
CUDNN_NO_REORDER = 1,
|
| 91 |
+
} cudnnReorderType_t;
|
| 92 |
+
|
| 93 |
+
typedef struct cudnnConvolutionFwdAlgoPerfStruct {
|
| 94 |
+
cudnnConvolutionFwdAlgo_t algo;
|
| 95 |
+
cudnnStatus_t status;
|
| 96 |
+
float time;
|
| 97 |
+
size_t memory;
|
| 98 |
+
cudnnDeterminism_t determinism;
|
| 99 |
+
cudnnMathType_t mathType;
|
| 100 |
+
int reserved[3];
|
| 101 |
+
} cudnnConvolutionFwdAlgoPerf_t;
|
| 102 |
+
|
| 103 |
+
/* Create an instance of convolution descriptor */
|
| 104 |
+
cudnnStatus_t CUDNNWINAPI
|
| 105 |
+
cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc);
|
| 106 |
+
|
| 107 |
+
/* Destroy an instance of convolution descriptor */
|
| 108 |
+
cudnnStatus_t CUDNNWINAPI
|
| 109 |
+
cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc);
|
| 110 |
+
|
| 111 |
+
cudnnStatus_t CUDNNWINAPI
|
| 112 |
+
cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType);
|
| 113 |
+
|
| 114 |
+
cudnnStatus_t CUDNNWINAPI
|
| 115 |
+
cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType);
|
| 116 |
+
|
| 117 |
+
cudnnStatus_t CUDNNWINAPI
|
| 118 |
+
cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount);
|
| 119 |
+
|
| 120 |
+
cudnnStatus_t CUDNNWINAPI
|
| 121 |
+
cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount);
|
| 122 |
+
|
| 123 |
+
cudnnStatus_t CUDNNWINAPI
|
| 124 |
+
cudnnSetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType);
|
| 125 |
+
|
| 126 |
+
cudnnStatus_t CUDNNWINAPI
|
| 127 |
+
cudnnGetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType);
|
| 128 |
+
|
| 129 |
+
cudnnStatus_t CUDNNWINAPI
|
| 130 |
+
cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc,
|
| 131 |
+
int pad_h, /* zero-padding height */
|
| 132 |
+
int pad_w, /* zero-padding width */
|
| 133 |
+
int u, /* vertical filter stride */
|
| 134 |
+
int v, /* horizontal filter stride */
|
| 135 |
+
int dilation_h, /* filter dilation in the vertical dimension */
|
| 136 |
+
int dilation_w, /* filter dilation in the horizontal dimension */
|
| 137 |
+
cudnnConvolutionMode_t mode,
|
| 138 |
+
cudnnDataType_t computeType);
|
| 139 |
+
|
| 140 |
+
cudnnStatus_t CUDNNWINAPI
|
| 141 |
+
cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc,
|
| 142 |
+
int *pad_h, /* zero-padding height */
|
| 143 |
+
int *pad_w, /* zero-padding width */
|
| 144 |
+
int *u, /* vertical filter stride */
|
| 145 |
+
int *v, /* horizontal filter stride */
|
| 146 |
+
int *dilation_h, /* filter dilation in the vertical dimension */
|
| 147 |
+
int *dilation_w, /* filter dilation in the horizontal dimension */
|
| 148 |
+
cudnnConvolutionMode_t *mode,
|
| 149 |
+
cudnnDataType_t *computeType);
|
| 150 |
+
|
| 151 |
+
cudnnStatus_t CUDNNWINAPI
|
| 152 |
+
cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc,
|
| 153 |
+
int arrayLength, /* nbDims-2 size */
|
| 154 |
+
const int padA[],
|
| 155 |
+
const int filterStrideA[],
|
| 156 |
+
const int dilationA[],
|
| 157 |
+
cudnnConvolutionMode_t mode,
|
| 158 |
+
cudnnDataType_t computeType); /* convolution data type */
|
| 159 |
+
|
| 160 |
+
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
|
| 161 |
+
cudnnStatus_t CUDNNWINAPI
|
| 162 |
+
cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc,
|
| 163 |
+
int arrayLengthRequested,
|
| 164 |
+
int *arrayLength,
|
| 165 |
+
int padA[],
|
| 166 |
+
int strideA[],
|
| 167 |
+
int dilationA[],
|
| 168 |
+
cudnnConvolutionMode_t *mode,
|
| 169 |
+
cudnnDataType_t *computeType); /* convolution data type */
|
| 170 |
+
|
| 171 |
+
cudnnStatus_t CUDNNWINAPI
|
| 172 |
+
cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
|
| 173 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 174 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 175 |
+
int *n,
|
| 176 |
+
int *c,
|
| 177 |
+
int *h,
|
| 178 |
+
int *w);
|
| 179 |
+
|
| 180 |
+
/* Helper function to return the dimensions of the output tensor given a convolution descriptor */
|
| 181 |
+
cudnnStatus_t CUDNNWINAPI
|
| 182 |
+
cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc,
|
| 183 |
+
const cudnnTensorDescriptor_t inputTensorDesc,
|
| 184 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 185 |
+
int nbDims,
|
| 186 |
+
int tensorOuputDimA[]);
|
| 187 |
+
|
| 188 |
+
/* helper function to provide the convolution forward algo that fit best the requirement */
|
| 189 |
+
cudnnStatus_t CUDNNWINAPI
|
| 190 |
+
cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count);
|
| 191 |
+
|
| 192 |
+
cudnnStatus_t CUDNNWINAPI
|
| 193 |
+
cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle,
|
| 194 |
+
const cudnnTensorDescriptor_t srcDesc,
|
| 195 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 196 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 197 |
+
const cudnnTensorDescriptor_t destDesc,
|
| 198 |
+
const int requestedAlgoCount,
|
| 199 |
+
int *returnedAlgoCount,
|
| 200 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults);
|
| 201 |
+
|
| 202 |
+
cudnnStatus_t CUDNNWINAPI
|
| 203 |
+
cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle,
|
| 204 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 205 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 206 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 207 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 208 |
+
const int requestedAlgoCount,
|
| 209 |
+
int *returnedAlgoCount,
|
| 210 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults);
|
| 211 |
+
|
| 212 |
+
cudnnStatus_t CUDNNWINAPI
|
| 213 |
+
cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle,
|
| 214 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 215 |
+
const void *x,
|
| 216 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 217 |
+
const void *w,
|
| 218 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 219 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 220 |
+
void *y,
|
| 221 |
+
const int requestedAlgoCount,
|
| 222 |
+
int *returnedAlgoCount,
|
| 223 |
+
cudnnConvolutionFwdAlgoPerf_t *perfResults,
|
| 224 |
+
void *workSpace,
|
| 225 |
+
size_t workSpaceSizeInBytes);
|
| 226 |
+
|
| 227 |
+
cudnnStatus_t CUDNNWINAPI
|
| 228 |
+
cudnnIm2Col(cudnnHandle_t handle,
|
| 229 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 230 |
+
const void *x,
|
| 231 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 232 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 233 |
+
void *colBuffer);
|
| 234 |
+
|
| 235 |
+
cudnnStatus_t CUDNNWINAPI
|
| 236 |
+
cudnnReorderFilterAndBias(cudnnHandle_t handle,
|
| 237 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 238 |
+
cudnnReorderType_t reorderType,
|
| 239 |
+
const void *filterData,
|
| 240 |
+
void *reorderedFilterData,
|
| 241 |
+
int reorderBias,
|
| 242 |
+
const void *biasData,
|
| 243 |
+
void *reorderedBiasData);
|
| 244 |
+
|
| 245 |
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
|
| 246 |
+
cudnnStatus_t CUDNNWINAPI
|
| 247 |
+
cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle,
|
| 248 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 249 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 250 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 251 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 252 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 253 |
+
size_t *sizeInBytes);
|
| 254 |
+
|
| 255 |
+
/* Convolution functions: All of the form "output = alpha * Op(inputs) + beta * output" */
|
| 256 |
+
|
| 257 |
+
/* Function to perform the forward pass for batch convolution */
|
| 258 |
+
cudnnStatus_t CUDNNWINAPI
|
| 259 |
+
cudnnConvolutionForward(cudnnHandle_t handle,
|
| 260 |
+
const void *alpha,
|
| 261 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 262 |
+
const void *x,
|
| 263 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 264 |
+
const void *w,
|
| 265 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 266 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 267 |
+
void *workSpace,
|
| 268 |
+
size_t workSpaceSizeInBytes,
|
| 269 |
+
const void *beta,
|
| 270 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 271 |
+
void *y);
|
| 272 |
+
|
| 273 |
+
/* Fused conv/bias/activation operation : y = Act( alpha1 * conv(x) + alpha2 * z + bias ) */
|
| 274 |
+
cudnnStatus_t CUDNNWINAPI
|
| 275 |
+
cudnnConvolutionBiasActivationForward(cudnnHandle_t handle,
|
| 276 |
+
const void *alpha1,
|
| 277 |
+
const cudnnTensorDescriptor_t xDesc,
|
| 278 |
+
const void *x,
|
| 279 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 280 |
+
const void *w,
|
| 281 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 282 |
+
cudnnConvolutionFwdAlgo_t algo,
|
| 283 |
+
void *workSpace,
|
| 284 |
+
size_t workSpaceSizeInBytes,
|
| 285 |
+
const void *alpha2,
|
| 286 |
+
const cudnnTensorDescriptor_t zDesc,
|
| 287 |
+
const void *z,
|
| 288 |
+
const cudnnTensorDescriptor_t biasDesc,
|
| 289 |
+
const void *bias,
|
| 290 |
+
const cudnnActivationDescriptor_t activationDesc,
|
| 291 |
+
const cudnnTensorDescriptor_t yDesc,
|
| 292 |
+
void *y);
|
| 293 |
+
|
| 294 |
+
/* helper function to provide the convolution backward data algo that fit best the requirement */
|
| 295 |
+
|
| 296 |
+
typedef struct cudnnConvolutionBwdDataAlgoPerfStruct {
|
| 297 |
+
cudnnConvolutionBwdDataAlgo_t algo;
|
| 298 |
+
cudnnStatus_t status;
|
| 299 |
+
float time;
|
| 300 |
+
size_t memory;
|
| 301 |
+
cudnnDeterminism_t determinism;
|
| 302 |
+
cudnnMathType_t mathType;
|
| 303 |
+
int reserved[3];
|
| 304 |
+
} cudnnConvolutionBwdDataAlgoPerf_t;
|
| 305 |
+
|
| 306 |
+
cudnnStatus_t CUDNNWINAPI
|
| 307 |
+
cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count);
|
| 308 |
+
|
| 309 |
+
cudnnStatus_t CUDNNWINAPI
|
| 310 |
+
cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle,
|
| 311 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 312 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 313 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 314 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 315 |
+
const int requestedAlgoCount,
|
| 316 |
+
int *returnedAlgoCount,
|
| 317 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
|
| 318 |
+
|
| 319 |
+
cudnnStatus_t CUDNNWINAPI
|
| 320 |
+
cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle,
|
| 321 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 322 |
+
const void *w,
|
| 323 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 324 |
+
const void *dy,
|
| 325 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 326 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 327 |
+
void *dx,
|
| 328 |
+
const int requestedAlgoCount,
|
| 329 |
+
int *returnedAlgoCount,
|
| 330 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults,
|
| 331 |
+
void *workSpace,
|
| 332 |
+
size_t workSpaceSizeInBytes);
|
| 333 |
+
|
| 334 |
+
cudnnStatus_t CUDNNWINAPI
|
| 335 |
+
cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle,
|
| 336 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 337 |
+
const cudnnTensorDescriptor_t diffDesc,
|
| 338 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 339 |
+
const cudnnTensorDescriptor_t gradDesc,
|
| 340 |
+
const int requestedAlgoCount,
|
| 341 |
+
int *returnedAlgoCount,
|
| 342 |
+
cudnnConvolutionBwdDataAlgoPerf_t *perfResults);
|
| 343 |
+
|
| 344 |
+
/*
|
| 345 |
+
* convolution algorithm (which requires potentially some workspace)
|
| 346 |
+
*/
|
| 347 |
+
|
| 348 |
+
/* Helper function to return the minimum size of the workspace to be passed to the convolution given an algo*/
|
| 349 |
+
cudnnStatus_t CUDNNWINAPI
|
| 350 |
+
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle,
|
| 351 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 352 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 353 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 354 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 355 |
+
cudnnConvolutionBwdDataAlgo_t algo,
|
| 356 |
+
size_t *sizeInBytes);
|
| 357 |
+
|
| 358 |
+
cudnnStatus_t CUDNNWINAPI
|
| 359 |
+
cudnnConvolutionBackwardData(cudnnHandle_t handle,
|
| 360 |
+
const void *alpha,
|
| 361 |
+
const cudnnFilterDescriptor_t wDesc,
|
| 362 |
+
const void *w,
|
| 363 |
+
const cudnnTensorDescriptor_t dyDesc,
|
| 364 |
+
const void *dy,
|
| 365 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 366 |
+
cudnnConvolutionBwdDataAlgo_t algo,
|
| 367 |
+
void *workSpace,
|
| 368 |
+
size_t workSpaceSizeInBytes,
|
| 369 |
+
const void *beta,
|
| 370 |
+
const cudnnTensorDescriptor_t dxDesc,
|
| 371 |
+
void *dx);
|
| 372 |
+
|
| 373 |
+
/* Helper function to calculate folding descriptors for dgrad */
|
| 374 |
+
cudnnStatus_t CUDNNWINAPI
|
| 375 |
+
cudnnGetFoldedConvBackwardDataDescriptors(const cudnnHandle_t handle,
|
| 376 |
+
const cudnnFilterDescriptor_t filterDesc,
|
| 377 |
+
const cudnnTensorDescriptor_t diffDesc,
|
| 378 |
+
const cudnnConvolutionDescriptor_t convDesc,
|
| 379 |
+
const cudnnTensorDescriptor_t gradDesc,
|
| 380 |
+
const cudnnTensorFormat_t transformFormat,
|
| 381 |
+
cudnnFilterDescriptor_t foldedFilterDesc,
|
| 382 |
+
cudnnTensorDescriptor_t paddedDiffDesc,
|
| 383 |
+
cudnnConvolutionDescriptor_t foldedConvDesc,
|
| 384 |
+
cudnnTensorDescriptor_t foldedGradDesc,
|
| 385 |
+
cudnnTensorTransformDescriptor_t filterFoldTransDesc,
|
| 386 |
+
cudnnTensorTransformDescriptor_t diffPadTransDesc,
|
| 387 |
+
cudnnTensorTransformDescriptor_t gradFoldTransDesc,
|
| 388 |
+
cudnnTensorTransformDescriptor_t gradUnfoldTransDesc);
|
| 389 |
+
|
| 390 |
+
/* cudnnFusedOps... */
|
| 391 |
+
struct cudnnFusedOpsConstParamStruct;
|
| 392 |
+
typedef struct cudnnFusedOpsConstParamStruct *cudnnFusedOpsConstParamPack_t;
|
| 393 |
+
|
| 394 |
+
struct cudnnFusedOpsVariantParamStruct;
|
| 395 |
+
typedef struct cudnnFusedOpsVariantParamStruct *cudnnFusedOpsVariantParamPack_t;
|
| 396 |
+
|
| 397 |
+
struct cudnnFusedOpsPlanStruct;
|
| 398 |
+
typedef struct cudnnFusedOpsPlanStruct *cudnnFusedOpsPlan_t;
|
| 399 |
+
|
| 400 |
+
typedef enum {
|
| 401 |
+
/* each op in [ ] can be disabled by passing NULL ptr */
|
| 402 |
+
/* [per channel scale], [per channel bias], [activation], convolution, [generate BN stats] */
|
| 403 |
+
CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS = 0,
|
| 404 |
+
/* [per channel scale], [per channel bias], [activation], convolutionBackwardWeights */
|
| 405 |
+
CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD = 1,
|
| 406 |
+
/* utility for BN training in BN-conv fusion */
|
| 407 |
+
/* computes the equivalent scale and bias from ySum ySqSum and learned scale, bias */
|
| 408 |
+
/* optionally update running stats and generate saved stats */
|
| 409 |
+
CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING = 2,
|
| 410 |
+
/* utility for BN inference in BN-conv fusion */
|
| 411 |
+
/* computes the equivalent scale and bias from learned running stats and learned scale, bias */
|
| 412 |
+
CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE = 3,
|
| 413 |
+
/* reserved for future use: convolution, [per channel scale], [per channel bias], [residual add], [activation] */
|
| 414 |
+
CUDNN_FUSED_CONV_SCALE_BIAS_ADD_ACTIVATION = 4,
|
| 415 |
+
/* reserved for future use: [per channel scale], [per channel bias], [residual add], activation, bitmask */
|
| 416 |
+
CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK = 5,
|
| 417 |
+
/* reserved for future use */
|
| 418 |
+
CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM = 6,
|
| 419 |
+
} cudnnFusedOps_t;
|
| 420 |
+
|
| 421 |
+
typedef enum {
|
| 422 |
+
/* set XDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 423 |
+
/* get XDESC: pass previously created cudnnTensorDescriptor_t */
|
| 424 |
+
CUDNN_PARAM_XDESC = 0,
|
| 425 |
+
/* set/get XDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 426 |
+
CUDNN_PARAM_XDATA_PLACEHOLDER = 1,
|
| 427 |
+
/* set/get BN_MODE: pass cudnnBatchNormMode_t* */
|
| 428 |
+
CUDNN_PARAM_BN_MODE = 2,
|
| 429 |
+
/* set CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 430 |
+
/* get CUDNN_PARAM_BN_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 431 |
+
CUDNN_PARAM_BN_EQSCALEBIAS_DESC = 3,
|
| 432 |
+
/* set/get BN_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 433 |
+
CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER = 4,
|
| 434 |
+
/* set/get BN_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 435 |
+
CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER = 5,
|
| 436 |
+
/* set ACTIVATION_DESC: pass previously initialized cudnnActivationDescriptor_t */
|
| 437 |
+
/* get ACTIVATION_DESC: pass previously created cudnnActivationDescriptor_t */
|
| 438 |
+
CUDNN_PARAM_ACTIVATION_DESC = 6,
|
| 439 |
+
/* set CONV_DESC: pass previously initialized cudnnConvolutionDescriptor_t */
|
| 440 |
+
/* get CONV_DESC: pass previously created cudnnConvolutionDescriptor_t */
|
| 441 |
+
CUDNN_PARAM_CONV_DESC = 7,
|
| 442 |
+
/* set WDESC: pass previously initialized cudnnFilterDescriptor_t */
|
| 443 |
+
/* get WDESC: pass previously created cudnnFilterDescriptor_t */
|
| 444 |
+
CUDNN_PARAM_WDESC = 8,
|
| 445 |
+
/* set/get WDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 446 |
+
CUDNN_PARAM_WDATA_PLACEHOLDER = 9,
|
| 447 |
+
/* set DWDESC: pass previously initialized cudnnFilterDescriptor_t */
|
| 448 |
+
/* get DWDESC: pass previously created cudnnFilterDescriptor_t */
|
| 449 |
+
CUDNN_PARAM_DWDESC = 10,
|
| 450 |
+
/* set/get DWDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 451 |
+
CUDNN_PARAM_DWDATA_PLACEHOLDER = 11,
|
| 452 |
+
/* set YDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 453 |
+
/* get YDESC: pass previously created cudnnTensorDescriptor_t */
|
| 454 |
+
CUDNN_PARAM_YDESC = 12,
|
| 455 |
+
/* set/get YDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 456 |
+
CUDNN_PARAM_YDATA_PLACEHOLDER = 13,
|
| 457 |
+
/* set DYDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 458 |
+
/* get DYDESC: pass previously created cudnnTensorDescriptor_t */
|
| 459 |
+
CUDNN_PARAM_DYDESC = 14,
|
| 460 |
+
/* set/get DYDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 461 |
+
CUDNN_PARAM_DYDATA_PLACEHOLDER = 15,
|
| 462 |
+
/* set YSTATS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 463 |
+
/* get YSTATS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 464 |
+
CUDNN_PARAM_YSTATS_DESC = 16,
|
| 465 |
+
/* set/get YSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 466 |
+
CUDNN_PARAM_YSUM_PLACEHOLDER = 17,
|
| 467 |
+
/* set/get YSQSUM_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 468 |
+
CUDNN_PARAM_YSQSUM_PLACEHOLDER = 18,
|
| 469 |
+
/* set CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 470 |
+
/* get CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 471 |
+
CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC = 19,
|
| 472 |
+
/* set/get CUDNN_PARAM_BN_SCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 473 |
+
CUDNN_PARAM_BN_SCALE_PLACEHOLDER = 20,
|
| 474 |
+
/* set/get CUDNN_PARAM_BN_BIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 475 |
+
CUDNN_PARAM_BN_BIAS_PLACEHOLDER = 21,
|
| 476 |
+
/* set/get CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 477 |
+
CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER = 22,
|
| 478 |
+
/* set/get CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 479 |
+
CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER = 23,
|
| 480 |
+
/* set/get CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 481 |
+
CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER = 24,
|
| 482 |
+
/* set/get CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 483 |
+
CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER = 25,
|
| 484 |
+
|
| 485 |
+
/* set ZDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 486 |
+
/* get ZDESC: pass previously created cudnnTensorDescriptor_t */
|
| 487 |
+
CUDNN_PARAM_ZDESC = 26,
|
| 488 |
+
/* set/get ZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 489 |
+
CUDNN_PARAM_ZDATA_PLACEHOLDER = 27,
|
| 490 |
+
/* set BN_Z_EQSCALEBIAS_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 491 |
+
/* get BN_Z_EQSCALEBIAS_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 492 |
+
CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC = 28,
|
| 493 |
+
/* set/get BN_Z_EQSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 494 |
+
CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER = 29,
|
| 495 |
+
/* set/get BN_Z_EQBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 496 |
+
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER = 30,
|
| 497 |
+
|
| 498 |
+
/* set ACTIVATION_BITMASK_DESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 499 |
+
/* get ACTIVATION_BITMASK_DESC: pass previously created cudnnTensorDescriptor_t */
|
| 500 |
+
CUDNN_PARAM_ACTIVATION_BITMASK_DESC = 31,
|
| 501 |
+
/* set/get ACTIVATION_BITMASK_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 502 |
+
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER = 32,
|
| 503 |
+
|
| 504 |
+
/* set DXDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 505 |
+
/* get DXDESC: pass previously created cudnnTensorDescriptor_t */
|
| 506 |
+
CUDNN_PARAM_DXDESC = 33,
|
| 507 |
+
/* set/get DXDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 508 |
+
CUDNN_PARAM_DXDATA_PLACEHOLDER = 34,
|
| 509 |
+
/* set DZDESC: pass previously initialized cudnnTensorDescriptor_t */
|
| 510 |
+
/* get DZDESC: pass previously created cudnnTensorDescriptor_t */
|
| 511 |
+
CUDNN_PARAM_DZDESC = 35,
|
| 512 |
+
/* set/get DZDATA_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 513 |
+
CUDNN_PARAM_DZDATA_PLACEHOLDER = 36,
|
| 514 |
+
/* set/get CUDNN_PARAM_BN_DSCALE_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 515 |
+
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER = 37,
|
| 516 |
+
/* set/get CUDNN_PARAM_BN_DBIAS_PLACEHOLDER: pass cudnnFusedOpsPointerPlaceHolder_t* */
|
| 517 |
+
CUDNN_PARAM_BN_DBIAS_PLACEHOLDER = 38,
|
| 518 |
+
} cudnnFusedOpsConstParamLabel_t;
|
| 519 |
+
|
| 520 |
+
typedef enum {
|
| 521 |
+
CUDNN_PTR_NULL = 0,
|
| 522 |
+
CUDNN_PTR_ELEM_ALIGNED = 1,
|
| 523 |
+
CUDNN_PTR_16B_ALIGNED = 2,
|
| 524 |
+
} cudnnFusedOpsPointerPlaceHolder_t;
|
| 525 |
+
|
| 526 |
+
typedef enum {
|
| 527 |
+
/* set: pass void* pointing to dev memory */
|
| 528 |
+
/* get: pass void** pointing to host memory */
|
| 529 |
+
CUDNN_PTR_XDATA = 0,
|
| 530 |
+
CUDNN_PTR_BN_EQSCALE = 1,
|
| 531 |
+
CUDNN_PTR_BN_EQBIAS = 2,
|
| 532 |
+
CUDNN_PTR_WDATA = 3,
|
| 533 |
+
CUDNN_PTR_DWDATA = 4,
|
| 534 |
+
CUDNN_PTR_YDATA = 5,
|
| 535 |
+
CUDNN_PTR_DYDATA = 6,
|
| 536 |
+
CUDNN_PTR_YSUM = 7,
|
| 537 |
+
CUDNN_PTR_YSQSUM = 8,
|
| 538 |
+
CUDNN_PTR_WORKSPACE = 9,
|
| 539 |
+
CUDNN_PTR_BN_SCALE = 10,
|
| 540 |
+
CUDNN_PTR_BN_BIAS = 11,
|
| 541 |
+
CUDNN_PTR_BN_SAVED_MEAN = 12,
|
| 542 |
+
CUDNN_PTR_BN_SAVED_INVSTD = 13,
|
| 543 |
+
CUDNN_PTR_BN_RUNNING_MEAN = 14,
|
| 544 |
+
CUDNN_PTR_BN_RUNNING_VAR = 15,
|
| 545 |
+
CUDNN_PTR_ZDATA = 16,
|
| 546 |
+
CUDNN_PTR_BN_Z_EQSCALE = 17,
|
| 547 |
+
CUDNN_PTR_BN_Z_EQBIAS = 18,
|
| 548 |
+
CUDNN_PTR_ACTIVATION_BITMASK = 19,
|
| 549 |
+
CUDNN_PTR_DXDATA = 20,
|
| 550 |
+
CUDNN_PTR_DZDATA = 21,
|
| 551 |
+
CUDNN_PTR_BN_DSCALE = 22,
|
| 552 |
+
CUDNN_PTR_BN_DBIAS = 23,
|
| 553 |
+
|
| 554 |
+
/* set/get: pass size_t* pointing to host memory */
|
| 555 |
+
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES = 100,
|
| 556 |
+
/* set/get: pass int64_t* pointing to host memory */
|
| 557 |
+
CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT = 101,
|
| 558 |
+
/* set/get: pass double* pointing to host memory */
|
| 559 |
+
CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR = 102,
|
| 560 |
+
/* set/get: pass double* pointing to host memory */
|
| 561 |
+
CUDNN_SCALAR_DOUBLE_BN_EPSILON = 103,
|
| 562 |
+
} cudnnFusedOpsVariantParamLabel_t;
|
| 563 |
+
|
| 564 |
+
cudnnStatus_t CUDNNWINAPI
|
| 565 |
+
cudnnCnnInferVersionCheck(void);
|
| 566 |
+
|
| 567 |
+
#if defined(__cplusplus)
|
| 568 |
+
}
|
| 569 |
+
#endif
|
| 570 |
+
|
| 571 |
+
#endif /* CUDNN_CNN_INFER_H_ */
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn_version.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2017-2022 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* This source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* These Licensed Deliverables contained herein is PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/**
|
| 51 |
+
* \file: The master cuDNN version file.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#ifndef CUDNN_VERSION_H_
|
| 55 |
+
#define CUDNN_VERSION_H_
|
| 56 |
+
|
| 57 |
+
#define CUDNN_MAJOR 8
|
| 58 |
+
#define CUDNN_MINOR 7
|
| 59 |
+
#define CUDNN_PATCHLEVEL 0
|
| 60 |
+
|
| 61 |
+
#define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
|
| 62 |
+
|
| 63 |
+
/* cannot use constexpr here since this is a C-only file */
|
| 64 |
+
/* Below is the max SM version this cuDNN library is aware of and supports natively */
|
| 65 |
+
|
| 66 |
+
#define CUDNN_MAX_SM_MAJOR_NUMBER 9
|
| 67 |
+
#define CUDNN_MAX_SM_MINOR_NUMBER 0
|
| 68 |
+
#define CUDNN_MAX_DEVICE_VERSION (CUDNN_MAX_SM_MAJOR_NUMBER * 100) + (CUDNN_MAX_SM_MINOR_NUMBER * 10)
|
| 69 |
+
|
| 70 |
+
#endif /* CUDNN_VERSION_H */
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (213 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/cudalibxt.h
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Copyright 2013,2014 NVIDIA Corporation. All rights reserved.
|
| 2 |
+
*
|
| 3 |
+
* NOTICE TO LICENSEE:
|
| 4 |
+
*
|
| 5 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 6 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 7 |
+
* international Copyright laws.
|
| 8 |
+
*
|
| 9 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 10 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 11 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 12 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 13 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 14 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 15 |
+
* of the Licensed Deliverables to any third party without the express
|
| 16 |
+
* written consent of NVIDIA is prohibited.
|
| 17 |
+
*
|
| 18 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 19 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 20 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 21 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 22 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 23 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 24 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 25 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 26 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 27 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 28 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 29 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 30 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 31 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 32 |
+
*
|
| 33 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 34 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 35 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 36 |
+
* computer software documentation" as such terms are used in 48
|
| 37 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 38 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 39 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 40 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 41 |
+
* only those rights set forth herein.
|
| 42 |
+
*
|
| 43 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 44 |
+
* software must include, in the user documentation and internal
|
| 45 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 46 |
+
* Users Notice.
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
/*!
|
| 50 |
+
* \file cudalibxt.h
|
| 51 |
+
* \brief Public header file for the NVIDIA library multi-GPU support structures
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
#ifndef _CUDA_LIB_XT_H_
|
| 55 |
+
#define _CUDA_LIB_XT_H_
|
| 56 |
+
#include <cuda_runtime.h>
|
| 57 |
+
|
| 58 |
+
#define CUDA_XT_DESCRIPTOR_VERSION 0x01000000 // This is added to CUDART_VERSION
|
| 59 |
+
|
| 60 |
+
enum cudaXtCopyType_t {
|
| 61 |
+
LIB_XT_COPY_HOST_TO_DEVICE,
|
| 62 |
+
LIB_XT_COPY_DEVICE_TO_HOST,
|
| 63 |
+
LIB_XT_COPY_DEVICE_TO_DEVICE
|
| 64 |
+
} ;
|
| 65 |
+
typedef enum cudaXtCopyType_t cudaLibXtCopyType;
|
| 66 |
+
|
| 67 |
+
enum libFormat_t {
|
| 68 |
+
LIB_FORMAT_CUFFT = 0x0,
|
| 69 |
+
LIB_FORMAT_UNDEFINED = 0x1
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
typedef enum libFormat_t libFormat;
|
| 73 |
+
|
| 74 |
+
#define MAX_CUDA_DESCRIPTOR_GPUS 64
|
| 75 |
+
|
| 76 |
+
struct cudaXtDesc_t{
|
| 77 |
+
int version; //descriptor version
|
| 78 |
+
int nGPUs; //number of GPUs
|
| 79 |
+
int GPUs[MAX_CUDA_DESCRIPTOR_GPUS]; //array of device IDs
|
| 80 |
+
void *data[MAX_CUDA_DESCRIPTOR_GPUS]; //array of pointers to data, one per GPU
|
| 81 |
+
size_t size[MAX_CUDA_DESCRIPTOR_GPUS]; //array of data sizes, one per GPU
|
| 82 |
+
void *cudaXtState; //opaque CUDA utility structure
|
| 83 |
+
};
|
| 84 |
+
typedef struct cudaXtDesc_t cudaXtDesc;
|
| 85 |
+
|
| 86 |
+
struct cudaLibXtDesc_t{
|
| 87 |
+
int version; //descriptor version
|
| 88 |
+
cudaXtDesc *descriptor; //multi-GPU memory descriptor
|
| 89 |
+
libFormat library; //which library recognizes the format
|
| 90 |
+
int subFormat; //library specific enumerator of sub formats
|
| 91 |
+
void *libDescriptor; //library specific descriptor e.g. FFT transform plan object
|
| 92 |
+
};
|
| 93 |
+
typedef struct cudaLibXtDesc_t cudaLibXtDesc;
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
#endif
|
| 97 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/include/cufftXt.h
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
/* Copyright 2005-2021 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* NOTICE TO LICENSEE:
|
| 5 |
+
*
|
| 6 |
+
* The source code and/or documentation ("Licensed Deliverables") are
|
| 7 |
+
* subject to NVIDIA intellectual property rights under U.S. and
|
| 8 |
+
* international Copyright laws.
|
| 9 |
+
*
|
| 10 |
+
* The Licensed Deliverables contained herein are PROPRIETARY and
|
| 11 |
+
* CONFIDENTIAL to NVIDIA and are being provided under the terms and
|
| 12 |
+
* conditions of a form of NVIDIA software license agreement by and
|
| 13 |
+
* between NVIDIA and Licensee ("License Agreement") or electronically
|
| 14 |
+
* accepted by Licensee. Notwithstanding any terms or conditions to
|
| 15 |
+
* the contrary in the License Agreement, reproduction or disclosure
|
| 16 |
+
* of the Licensed Deliverables to any third party without the express
|
| 17 |
+
* written consent of NVIDIA is prohibited.
|
| 18 |
+
*
|
| 19 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 20 |
+
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
|
| 21 |
+
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. THEY ARE
|
| 22 |
+
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
|
| 23 |
+
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
|
| 24 |
+
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 25 |
+
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
| 26 |
+
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
|
| 27 |
+
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
|
| 28 |
+
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
|
| 29 |
+
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
|
| 30 |
+
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
|
| 31 |
+
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
|
| 32 |
+
* OF THESE LICENSED DELIVERABLES.
|
| 33 |
+
*
|
| 34 |
+
* U.S. Government End Users. These Licensed Deliverables are a
|
| 35 |
+
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
|
| 36 |
+
* 1995), consisting of "commercial computer software" and "commercial
|
| 37 |
+
* computer software documentation" as such terms are used in 48
|
| 38 |
+
* C.F.R. 12.212 (SEPT 1995) and are provided to the U.S. Government
|
| 39 |
+
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
|
| 40 |
+
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
|
| 41 |
+
* U.S. Government End Users acquire the Licensed Deliverables with
|
| 42 |
+
* only those rights set forth herein.
|
| 43 |
+
*
|
| 44 |
+
* Any use of the Licensed Deliverables in individual and commercial
|
| 45 |
+
* software must include, in the user documentation and internal
|
| 46 |
+
* comments to the code, the above Disclaimer and U.S. Government End
|
| 47 |
+
* Users Notice.
|
| 48 |
+
*/
|
| 49 |
+
|
| 50 |
+
/*!
|
| 51 |
+
* \file cufftXt.h
|
| 52 |
+
* \brief Public header file for the NVIDIA CUDA FFT library (CUFFT)
|
| 53 |
+
*/
|
| 54 |
+
|
| 55 |
+
#ifndef _CUFFTXT_H_
|
| 56 |
+
#define _CUFFTXT_H_
|
| 57 |
+
#include "cudalibxt.h"
|
| 58 |
+
#include "cufft.h"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
#ifndef CUFFTAPI
|
| 62 |
+
#ifdef _WIN32
|
| 63 |
+
#define CUFFTAPI __stdcall
|
| 64 |
+
#else
|
| 65 |
+
#define CUFFTAPI
|
| 66 |
+
#endif
|
| 67 |
+
#endif
|
| 68 |
+
|
| 69 |
+
#ifdef __cplusplus
|
| 70 |
+
extern "C" {
|
| 71 |
+
#endif
|
| 72 |
+
|
| 73 |
+
//
|
| 74 |
+
// cufftXtSubFormat identifies the data layout of
|
| 75 |
+
// a memory descriptor owned by cufft.
|
| 76 |
+
// note that multi GPU cufft does not yet support out-of-place transforms
|
| 77 |
+
//
|
| 78 |
+
|
| 79 |
+
typedef enum cufftXtSubFormat_t {
|
| 80 |
+
CUFFT_XT_FORMAT_INPUT = 0x00, //by default input is in linear order across GPUs
|
| 81 |
+
CUFFT_XT_FORMAT_OUTPUT = 0x01, //by default output is in scrambled order depending on transform
|
| 82 |
+
CUFFT_XT_FORMAT_INPLACE = 0x02, //by default inplace is input order, which is linear across GPUs
|
| 83 |
+
CUFFT_XT_FORMAT_INPLACE_SHUFFLED = 0x03, //shuffled output order after execution of the transform
|
| 84 |
+
CUFFT_XT_FORMAT_1D_INPUT_SHUFFLED = 0x04, //shuffled input order prior to execution of 1D transforms
|
| 85 |
+
CUFFT_XT_FORMAT_DISTRIBUTED_INPUT = 0x05,
|
| 86 |
+
CUFFT_XT_FORMAT_DISTRIBUTED_OUTPUT = 0x06,
|
| 87 |
+
CUFFT_FORMAT_UNDEFINED = 0x07
|
| 88 |
+
} cufftXtSubFormat;
|
| 89 |
+
|
| 90 |
+
//
|
| 91 |
+
// cufftXtCopyType specifies the type of copy for cufftXtMemcpy
|
| 92 |
+
//
|
| 93 |
+
typedef enum cufftXtCopyType_t {
|
| 94 |
+
CUFFT_COPY_HOST_TO_DEVICE = 0x00,
|
| 95 |
+
CUFFT_COPY_DEVICE_TO_HOST = 0x01,
|
| 96 |
+
CUFFT_COPY_DEVICE_TO_DEVICE = 0x02,
|
| 97 |
+
CUFFT_COPY_UNDEFINED = 0x03
|
| 98 |
+
} cufftXtCopyType;
|
| 99 |
+
|
| 100 |
+
//
|
| 101 |
+
// cufftXtQueryType specifies the type of query for cufftXtQueryPlan
|
| 102 |
+
//
|
| 103 |
+
typedef enum cufftXtQueryType_t {
|
| 104 |
+
CUFFT_QUERY_1D_FACTORS = 0x00,
|
| 105 |
+
CUFFT_QUERY_UNDEFINED = 0x01
|
| 106 |
+
} cufftXtQueryType;
|
| 107 |
+
|
| 108 |
+
typedef struct cufftXt1dFactors_t {
|
| 109 |
+
long long int size;
|
| 110 |
+
long long int stringCount;
|
| 111 |
+
long long int stringLength;
|
| 112 |
+
long long int substringLength;
|
| 113 |
+
long long int factor1;
|
| 114 |
+
long long int factor2;
|
| 115 |
+
long long int stringMask;
|
| 116 |
+
long long int substringMask;
|
| 117 |
+
long long int factor1Mask;
|
| 118 |
+
long long int factor2Mask;
|
| 119 |
+
int stringShift;
|
| 120 |
+
int substringShift;
|
| 121 |
+
int factor1Shift;
|
| 122 |
+
int factor2Shift;
|
| 123 |
+
} cufftXt1dFactors;
|
| 124 |
+
|
| 125 |
+
//
|
| 126 |
+
// cufftXtWorkAreaPolicy specifies policy for cufftXtSetWorkAreaPolicy
|
| 127 |
+
//
|
| 128 |
+
typedef enum cufftXtWorkAreaPolicy_t {
|
| 129 |
+
CUFFT_WORKAREA_MINIMAL = 0, /* maximum reduction */
|
| 130 |
+
CUFFT_WORKAREA_USER = 1, /* use workSize parameter as limit */
|
| 131 |
+
CUFFT_WORKAREA_PERFORMANCE = 2, /* default - 1x overhead or more, maximum performance */
|
| 132 |
+
} cufftXtWorkAreaPolicy;
|
| 133 |
+
|
| 134 |
+
// multi-GPU routines
|
| 135 |
+
cufftResult CUFFTAPI cufftXtSetGPUs(cufftHandle handle, int nGPUs, int *whichGPUs);
|
| 136 |
+
|
| 137 |
+
cufftResult CUFFTAPI cufftXtMalloc(cufftHandle plan,
|
| 138 |
+
cudaLibXtDesc ** descriptor,
|
| 139 |
+
cufftXtSubFormat format);
|
| 140 |
+
|
| 141 |
+
cufftResult CUFFTAPI cufftXtMemcpy(cufftHandle plan,
|
| 142 |
+
void *dstPointer,
|
| 143 |
+
void *srcPointer,
|
| 144 |
+
cufftXtCopyType type);
|
| 145 |
+
|
| 146 |
+
cufftResult CUFFTAPI cufftXtFree(cudaLibXtDesc *descriptor);
|
| 147 |
+
|
| 148 |
+
cufftResult CUFFTAPI cufftXtSetWorkArea(cufftHandle plan, void **workArea);
|
| 149 |
+
|
| 150 |
+
cufftResult CUFFTAPI cufftXtExecDescriptorC2C(cufftHandle plan,
|
| 151 |
+
cudaLibXtDesc *input,
|
| 152 |
+
cudaLibXtDesc *output,
|
| 153 |
+
int direction);
|
| 154 |
+
|
| 155 |
+
cufftResult CUFFTAPI cufftXtExecDescriptorR2C(cufftHandle plan,
|
| 156 |
+
cudaLibXtDesc *input,
|
| 157 |
+
cudaLibXtDesc *output);
|
| 158 |
+
|
| 159 |
+
cufftResult CUFFTAPI cufftXtExecDescriptorC2R(cufftHandle plan,
|
| 160 |
+
cudaLibXtDesc *input,
|
| 161 |
+
cudaLibXtDesc *output);
|
| 162 |
+
|
| 163 |
+
cufftResult CUFFTAPI cufftXtExecDescriptorZ2Z(cufftHandle plan,
|
| 164 |
+
cudaLibXtDesc *input,
|
| 165 |
+
cudaLibXtDesc *output,
|
| 166 |
+
int direction);
|
| 167 |
+
|
| 168 |
+
cufftResult CUFFTAPI cufftXtExecDescriptorD2Z(cufftHandle plan,
|
| 169 |
+
cudaLibXtDesc *input,
|
| 170 |
+
cudaLibXtDesc *output);
|
| 171 |
+
|
| 172 |
+
cufftResult CUFFTAPI cufftXtExecDescriptorZ2D(cufftHandle plan,
|
| 173 |
+
cudaLibXtDesc *input,
|
| 174 |
+
cudaLibXtDesc *output);
|
| 175 |
+
|
| 176 |
+
// Utility functions
|
| 177 |
+
|
| 178 |
+
cufftResult CUFFTAPI cufftXtQueryPlan(cufftHandle plan, void *queryStruct, cufftXtQueryType queryType);
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
// callbacks
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
typedef enum cufftXtCallbackType_t {
|
| 185 |
+
CUFFT_CB_LD_COMPLEX = 0x0,
|
| 186 |
+
CUFFT_CB_LD_COMPLEX_DOUBLE = 0x1,
|
| 187 |
+
CUFFT_CB_LD_REAL = 0x2,
|
| 188 |
+
CUFFT_CB_LD_REAL_DOUBLE = 0x3,
|
| 189 |
+
CUFFT_CB_ST_COMPLEX = 0x4,
|
| 190 |
+
CUFFT_CB_ST_COMPLEX_DOUBLE = 0x5,
|
| 191 |
+
CUFFT_CB_ST_REAL = 0x6,
|
| 192 |
+
CUFFT_CB_ST_REAL_DOUBLE = 0x7,
|
| 193 |
+
CUFFT_CB_UNDEFINED = 0x8
|
| 194 |
+
|
| 195 |
+
} cufftXtCallbackType;
|
| 196 |
+
|
| 197 |
+
typedef cufftComplex (*cufftCallbackLoadC)(void *dataIn, size_t offset, void *callerInfo, void *sharedPointer);
|
| 198 |
+
typedef cufftDoubleComplex (*cufftCallbackLoadZ)(void *dataIn, size_t offset, void *callerInfo, void *sharedPointer);
|
| 199 |
+
typedef cufftReal (*cufftCallbackLoadR)(void *dataIn, size_t offset, void *callerInfo, void *sharedPointer);
|
| 200 |
+
typedef cufftDoubleReal(*cufftCallbackLoadD)(void *dataIn, size_t offset, void *callerInfo, void *sharedPointer);
|
| 201 |
+
|
| 202 |
+
typedef void (*cufftCallbackStoreC)(void *dataOut, size_t offset, cufftComplex element, void *callerInfo, void *sharedPointer);
|
| 203 |
+
typedef void (*cufftCallbackStoreZ)(void *dataOut, size_t offset, cufftDoubleComplex element, void *callerInfo, void *sharedPointer);
|
| 204 |
+
typedef void (*cufftCallbackStoreR)(void *dataOut, size_t offset, cufftReal element, void *callerInfo, void *sharedPointer);
|
| 205 |
+
typedef void (*cufftCallbackStoreD)(void *dataOut, size_t offset, cufftDoubleReal element, void *callerInfo, void *sharedPointer);
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
cufftResult CUFFTAPI cufftXtSetCallback(cufftHandle plan, void **callback_routine, cufftXtCallbackType cbType, void **caller_info);
|
| 209 |
+
cufftResult CUFFTAPI cufftXtClearCallback(cufftHandle plan, cufftXtCallbackType cbType);
|
| 210 |
+
cufftResult CUFFTAPI cufftXtSetCallbackSharedSize(cufftHandle plan, cufftXtCallbackType cbType, size_t sharedSize);
|
| 211 |
+
|
| 212 |
+
cufftResult CUFFTAPI cufftXtMakePlanMany(cufftHandle plan,
|
| 213 |
+
int rank,
|
| 214 |
+
long long int *n,
|
| 215 |
+
long long int *inembed,
|
| 216 |
+
long long int istride,
|
| 217 |
+
long long int idist,
|
| 218 |
+
cudaDataType inputtype,
|
| 219 |
+
long long int *onembed,
|
| 220 |
+
long long int ostride,
|
| 221 |
+
long long int odist,
|
| 222 |
+
cudaDataType outputtype,
|
| 223 |
+
long long int batch,
|
| 224 |
+
size_t *workSize,
|
| 225 |
+
cudaDataType executiontype);
|
| 226 |
+
|
| 227 |
+
cufftResult CUFFTAPI cufftXtGetSizeMany(cufftHandle plan,
|
| 228 |
+
int rank,
|
| 229 |
+
long long int *n,
|
| 230 |
+
long long int *inembed,
|
| 231 |
+
long long int istride,
|
| 232 |
+
long long int idist,
|
| 233 |
+
cudaDataType inputtype,
|
| 234 |
+
long long int *onembed,
|
| 235 |
+
long long int ostride,
|
| 236 |
+
long long int odist,
|
| 237 |
+
cudaDataType outputtype,
|
| 238 |
+
long long int batch,
|
| 239 |
+
size_t *workSize,
|
| 240 |
+
cudaDataType executiontype);
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
cufftResult CUFFTAPI cufftXtExec(cufftHandle plan,
|
| 244 |
+
void *input,
|
| 245 |
+
void *output,
|
| 246 |
+
int direction);
|
| 247 |
+
|
| 248 |
+
cufftResult CUFFTAPI cufftXtExecDescriptor(cufftHandle plan,
|
| 249 |
+
cudaLibXtDesc *input,
|
| 250 |
+
cudaLibXtDesc *output,
|
| 251 |
+
int direction);
|
| 252 |
+
|
| 253 |
+
cufftResult CUFFTAPI cufftXtSetWorkAreaPolicy(cufftHandle plan, cufftXtWorkAreaPolicy policy, size_t *workSize);
|
| 254 |
+
|
| 255 |
+
typedef struct cufftBox3d_t {
|
| 256 |
+
size_t lower[3];
|
| 257 |
+
size_t upper[3];
|
| 258 |
+
size_t strides[3];
|
| 259 |
+
} cufftBox3d;
|
| 260 |
+
|
| 261 |
+
cufftResult CUFFTAPI cufftXtSetDistribution(cufftHandle plan,
|
| 262 |
+
const cufftBox3d *box_in,
|
| 263 |
+
const cufftBox3d *box_out);
|
| 264 |
+
|
| 265 |
+
#ifdef __cplusplus
|
| 266 |
+
}
|
| 267 |
+
#endif
|
| 268 |
+
|
| 269 |
+
#endif
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/nccl/include/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia_nvtx_cu11-11.8.86.dist-info/METADATA
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: nvidia-nvtx-cu11
|
| 3 |
+
Version: 11.8.86
|
| 4 |
+
Summary: NVIDIA Tools Extension
|
| 5 |
+
Home-page: https://developer.nvidia.com/cuda-zone
|
| 6 |
+
Author: Nvidia CUDA Installer Team
|
| 7 |
+
Author-email: cuda_installer@nvidia.com
|
| 8 |
+
License: NVIDIA Proprietary Software
|
| 9 |
+
Keywords: cuda,nvidia,runtime,machine learning,deep learning
|
| 10 |
+
Classifier: Development Status :: 4 - Beta
|
| 11 |
+
Classifier: Intended Audience :: Developers
|
| 12 |
+
Classifier: Intended Audience :: Education
|
| 13 |
+
Classifier: Intended Audience :: Science/Research
|
| 14 |
+
Classifier: License :: Other/Proprietary License
|
| 15 |
+
Classifier: Natural Language :: English
|
| 16 |
+
Classifier: Programming Language :: Python :: 3
|
| 17 |
+
Classifier: Programming Language :: Python :: 3.5
|
| 18 |
+
Classifier: Programming Language :: Python :: 3.6
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.7
|
| 20 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 21 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 22 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 23 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 24 |
+
Classifier: Programming Language :: Python :: 3 :: Only
|
| 25 |
+
Classifier: Topic :: Scientific/Engineering
|
| 26 |
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
| 27 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
| 28 |
+
Classifier: Topic :: Software Development
|
| 29 |
+
Classifier: Topic :: Software Development :: Libraries
|
| 30 |
+
Classifier: Operating System :: Microsoft :: Windows
|
| 31 |
+
Classifier: Operating System :: POSIX :: Linux
|
| 32 |
+
Requires-Python: >=3
|
| 33 |
+
License-File: License.txt
|
| 34 |
+
|
| 35 |
+
A C-based API for annotating events, code ranges, and resources in your applications. Applications which integrate NVTX can use the Visual Profiler to capture and visualize these events and ranges.
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia_nvtx_cu11-11.8.86.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: bdist_wheel (0.37.1)
|
| 3 |
+
Root-Is-Purelib: true
|
| 4 |
+
Tag: py3-none-manylinux1_x86_64
|
| 5 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/__pycache__/_elffile.cpython-311.pyc
ADDED
|
Binary file (5.53 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/_parser.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Handwritten parser of dependency specifiers.
|
| 2 |
+
|
| 3 |
+
The docstring for each __parse_* function contains EBNF-inspired grammar representing
|
| 4 |
+
the implementation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import ast
|
| 10 |
+
from typing import NamedTuple, Sequence, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from ._tokenizer import DEFAULT_RULES, Tokenizer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Node:
|
| 16 |
+
def __init__(self, value: str) -> None:
|
| 17 |
+
self.value = value
|
| 18 |
+
|
| 19 |
+
def __str__(self) -> str:
|
| 20 |
+
return self.value
|
| 21 |
+
|
| 22 |
+
def __repr__(self) -> str:
|
| 23 |
+
return f"<{self.__class__.__name__}('{self}')>"
|
| 24 |
+
|
| 25 |
+
def serialize(self) -> str:
|
| 26 |
+
raise NotImplementedError
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Variable(Node):
|
| 30 |
+
def serialize(self) -> str:
|
| 31 |
+
return str(self)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Value(Node):
|
| 35 |
+
def serialize(self) -> str:
|
| 36 |
+
return f'"{self}"'
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Op(Node):
|
| 40 |
+
def serialize(self) -> str:
|
| 41 |
+
return str(self)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
MarkerVar = Union[Variable, Value]
|
| 45 |
+
MarkerItem = Tuple[MarkerVar, Op, MarkerVar]
|
| 46 |
+
MarkerAtom = Union[MarkerItem, Sequence["MarkerAtom"]]
|
| 47 |
+
MarkerList = Sequence[Union["MarkerList", MarkerAtom, str]]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ParsedRequirement(NamedTuple):
|
| 51 |
+
name: str
|
| 52 |
+
url: str
|
| 53 |
+
extras: list[str]
|
| 54 |
+
specifier: str
|
| 55 |
+
marker: MarkerList | None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# --------------------------------------------------------------------------------------
|
| 59 |
+
# Recursive descent parser for dependency specifier
|
| 60 |
+
# --------------------------------------------------------------------------------------
|
| 61 |
+
def parse_requirement(source: str) -> ParsedRequirement:
|
| 62 |
+
return _parse_requirement(Tokenizer(source, rules=DEFAULT_RULES))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _parse_requirement(tokenizer: Tokenizer) -> ParsedRequirement:
|
| 66 |
+
"""
|
| 67 |
+
requirement = WS? IDENTIFIER WS? extras WS? requirement_details
|
| 68 |
+
"""
|
| 69 |
+
tokenizer.consume("WS")
|
| 70 |
+
|
| 71 |
+
name_token = tokenizer.expect(
|
| 72 |
+
"IDENTIFIER", expected="package name at the start of dependency specifier"
|
| 73 |
+
)
|
| 74 |
+
name = name_token.text
|
| 75 |
+
tokenizer.consume("WS")
|
| 76 |
+
|
| 77 |
+
extras = _parse_extras(tokenizer)
|
| 78 |
+
tokenizer.consume("WS")
|
| 79 |
+
|
| 80 |
+
url, specifier, marker = _parse_requirement_details(tokenizer)
|
| 81 |
+
tokenizer.expect("END", expected="end of dependency specifier")
|
| 82 |
+
|
| 83 |
+
return ParsedRequirement(name, url, extras, specifier, marker)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _parse_requirement_details(
|
| 87 |
+
tokenizer: Tokenizer,
|
| 88 |
+
) -> tuple[str, str, MarkerList | None]:
|
| 89 |
+
"""
|
| 90 |
+
requirement_details = AT URL (WS requirement_marker?)?
|
| 91 |
+
| specifier WS? (requirement_marker)?
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
specifier = ""
|
| 95 |
+
url = ""
|
| 96 |
+
marker = None
|
| 97 |
+
|
| 98 |
+
if tokenizer.check("AT"):
|
| 99 |
+
tokenizer.read()
|
| 100 |
+
tokenizer.consume("WS")
|
| 101 |
+
|
| 102 |
+
url_start = tokenizer.position
|
| 103 |
+
url = tokenizer.expect("URL", expected="URL after @").text
|
| 104 |
+
if tokenizer.check("END", peek=True):
|
| 105 |
+
return (url, specifier, marker)
|
| 106 |
+
|
| 107 |
+
tokenizer.expect("WS", expected="whitespace after URL")
|
| 108 |
+
|
| 109 |
+
# The input might end after whitespace.
|
| 110 |
+
if tokenizer.check("END", peek=True):
|
| 111 |
+
return (url, specifier, marker)
|
| 112 |
+
|
| 113 |
+
marker = _parse_requirement_marker(
|
| 114 |
+
tokenizer, span_start=url_start, after="URL and whitespace"
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
specifier_start = tokenizer.position
|
| 118 |
+
specifier = _parse_specifier(tokenizer)
|
| 119 |
+
tokenizer.consume("WS")
|
| 120 |
+
|
| 121 |
+
if tokenizer.check("END", peek=True):
|
| 122 |
+
return (url, specifier, marker)
|
| 123 |
+
|
| 124 |
+
marker = _parse_requirement_marker(
|
| 125 |
+
tokenizer,
|
| 126 |
+
span_start=specifier_start,
|
| 127 |
+
after=(
|
| 128 |
+
"version specifier"
|
| 129 |
+
if specifier
|
| 130 |
+
else "name and no valid version specifier"
|
| 131 |
+
),
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
return (url, specifier, marker)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _parse_requirement_marker(
|
| 138 |
+
tokenizer: Tokenizer, *, span_start: int, after: str
|
| 139 |
+
) -> MarkerList:
|
| 140 |
+
"""
|
| 141 |
+
requirement_marker = SEMICOLON marker WS?
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
if not tokenizer.check("SEMICOLON"):
|
| 145 |
+
tokenizer.raise_syntax_error(
|
| 146 |
+
f"Expected end or semicolon (after {after})",
|
| 147 |
+
span_start=span_start,
|
| 148 |
+
)
|
| 149 |
+
tokenizer.read()
|
| 150 |
+
|
| 151 |
+
marker = _parse_marker(tokenizer)
|
| 152 |
+
tokenizer.consume("WS")
|
| 153 |
+
|
| 154 |
+
return marker
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _parse_extras(tokenizer: Tokenizer) -> list[str]:
|
| 158 |
+
"""
|
| 159 |
+
extras = (LEFT_BRACKET wsp* extras_list? wsp* RIGHT_BRACKET)?
|
| 160 |
+
"""
|
| 161 |
+
if not tokenizer.check("LEFT_BRACKET", peek=True):
|
| 162 |
+
return []
|
| 163 |
+
|
| 164 |
+
with tokenizer.enclosing_tokens(
|
| 165 |
+
"LEFT_BRACKET",
|
| 166 |
+
"RIGHT_BRACKET",
|
| 167 |
+
around="extras",
|
| 168 |
+
):
|
| 169 |
+
tokenizer.consume("WS")
|
| 170 |
+
extras = _parse_extras_list(tokenizer)
|
| 171 |
+
tokenizer.consume("WS")
|
| 172 |
+
|
| 173 |
+
return extras
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _parse_extras_list(tokenizer: Tokenizer) -> list[str]:
|
| 177 |
+
"""
|
| 178 |
+
extras_list = identifier (wsp* ',' wsp* identifier)*
|
| 179 |
+
"""
|
| 180 |
+
extras: list[str] = []
|
| 181 |
+
|
| 182 |
+
if not tokenizer.check("IDENTIFIER"):
|
| 183 |
+
return extras
|
| 184 |
+
|
| 185 |
+
extras.append(tokenizer.read().text)
|
| 186 |
+
|
| 187 |
+
while True:
|
| 188 |
+
tokenizer.consume("WS")
|
| 189 |
+
if tokenizer.check("IDENTIFIER", peek=True):
|
| 190 |
+
tokenizer.raise_syntax_error("Expected comma between extra names")
|
| 191 |
+
elif not tokenizer.check("COMMA"):
|
| 192 |
+
break
|
| 193 |
+
|
| 194 |
+
tokenizer.read()
|
| 195 |
+
tokenizer.consume("WS")
|
| 196 |
+
|
| 197 |
+
extra_token = tokenizer.expect("IDENTIFIER", expected="extra name after comma")
|
| 198 |
+
extras.append(extra_token.text)
|
| 199 |
+
|
| 200 |
+
return extras
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _parse_specifier(tokenizer: Tokenizer) -> str:
|
| 204 |
+
"""
|
| 205 |
+
specifier = LEFT_PARENTHESIS WS? version_many WS? RIGHT_PARENTHESIS
|
| 206 |
+
| WS? version_many WS?
|
| 207 |
+
"""
|
| 208 |
+
with tokenizer.enclosing_tokens(
|
| 209 |
+
"LEFT_PARENTHESIS",
|
| 210 |
+
"RIGHT_PARENTHESIS",
|
| 211 |
+
around="version specifier",
|
| 212 |
+
):
|
| 213 |
+
tokenizer.consume("WS")
|
| 214 |
+
parsed_specifiers = _parse_version_many(tokenizer)
|
| 215 |
+
tokenizer.consume("WS")
|
| 216 |
+
|
| 217 |
+
return parsed_specifiers
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _parse_version_many(tokenizer: Tokenizer) -> str:
|
| 221 |
+
"""
|
| 222 |
+
version_many = (SPECIFIER (WS? COMMA WS? SPECIFIER)*)?
|
| 223 |
+
"""
|
| 224 |
+
parsed_specifiers = ""
|
| 225 |
+
while tokenizer.check("SPECIFIER"):
|
| 226 |
+
span_start = tokenizer.position
|
| 227 |
+
parsed_specifiers += tokenizer.read().text
|
| 228 |
+
if tokenizer.check("VERSION_PREFIX_TRAIL", peek=True):
|
| 229 |
+
tokenizer.raise_syntax_error(
|
| 230 |
+
".* suffix can only be used with `==` or `!=` operators",
|
| 231 |
+
span_start=span_start,
|
| 232 |
+
span_end=tokenizer.position + 1,
|
| 233 |
+
)
|
| 234 |
+
if tokenizer.check("VERSION_LOCAL_LABEL_TRAIL", peek=True):
|
| 235 |
+
tokenizer.raise_syntax_error(
|
| 236 |
+
"Local version label can only be used with `==` or `!=` operators",
|
| 237 |
+
span_start=span_start,
|
| 238 |
+
span_end=tokenizer.position,
|
| 239 |
+
)
|
| 240 |
+
tokenizer.consume("WS")
|
| 241 |
+
if not tokenizer.check("COMMA"):
|
| 242 |
+
break
|
| 243 |
+
parsed_specifiers += tokenizer.read().text
|
| 244 |
+
tokenizer.consume("WS")
|
| 245 |
+
|
| 246 |
+
return parsed_specifiers
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# --------------------------------------------------------------------------------------
|
| 250 |
+
# Recursive descent parser for marker expression
|
| 251 |
+
# --------------------------------------------------------------------------------------
|
| 252 |
+
def parse_marker(source: str) -> MarkerList:
|
| 253 |
+
return _parse_full_marker(Tokenizer(source, rules=DEFAULT_RULES))
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _parse_full_marker(tokenizer: Tokenizer) -> MarkerList:
|
| 257 |
+
retval = _parse_marker(tokenizer)
|
| 258 |
+
tokenizer.expect("END", expected="end of marker expression")
|
| 259 |
+
return retval
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _parse_marker(tokenizer: Tokenizer) -> MarkerList:
|
| 263 |
+
"""
|
| 264 |
+
marker = marker_atom (BOOLOP marker_atom)+
|
| 265 |
+
"""
|
| 266 |
+
expression = [_parse_marker_atom(tokenizer)]
|
| 267 |
+
while tokenizer.check("BOOLOP"):
|
| 268 |
+
token = tokenizer.read()
|
| 269 |
+
expr_right = _parse_marker_atom(tokenizer)
|
| 270 |
+
expression.extend((token.text, expr_right))
|
| 271 |
+
return expression
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _parse_marker_atom(tokenizer: Tokenizer) -> MarkerAtom:
|
| 275 |
+
"""
|
| 276 |
+
marker_atom = WS? LEFT_PARENTHESIS WS? marker WS? RIGHT_PARENTHESIS WS?
|
| 277 |
+
| WS? marker_item WS?
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
tokenizer.consume("WS")
|
| 281 |
+
if tokenizer.check("LEFT_PARENTHESIS", peek=True):
|
| 282 |
+
with tokenizer.enclosing_tokens(
|
| 283 |
+
"LEFT_PARENTHESIS",
|
| 284 |
+
"RIGHT_PARENTHESIS",
|
| 285 |
+
around="marker expression",
|
| 286 |
+
):
|
| 287 |
+
tokenizer.consume("WS")
|
| 288 |
+
marker: MarkerAtom = _parse_marker(tokenizer)
|
| 289 |
+
tokenizer.consume("WS")
|
| 290 |
+
else:
|
| 291 |
+
marker = _parse_marker_item(tokenizer)
|
| 292 |
+
tokenizer.consume("WS")
|
| 293 |
+
return marker
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _parse_marker_item(tokenizer: Tokenizer) -> MarkerItem:
|
| 297 |
+
"""
|
| 298 |
+
marker_item = WS? marker_var WS? marker_op WS? marker_var WS?
|
| 299 |
+
"""
|
| 300 |
+
tokenizer.consume("WS")
|
| 301 |
+
marker_var_left = _parse_marker_var(tokenizer)
|
| 302 |
+
tokenizer.consume("WS")
|
| 303 |
+
marker_op = _parse_marker_op(tokenizer)
|
| 304 |
+
tokenizer.consume("WS")
|
| 305 |
+
marker_var_right = _parse_marker_var(tokenizer)
|
| 306 |
+
tokenizer.consume("WS")
|
| 307 |
+
return (marker_var_left, marker_op, marker_var_right)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _parse_marker_var(tokenizer: Tokenizer) -> MarkerVar:
|
| 311 |
+
"""
|
| 312 |
+
marker_var = VARIABLE | QUOTED_STRING
|
| 313 |
+
"""
|
| 314 |
+
if tokenizer.check("VARIABLE"):
|
| 315 |
+
return process_env_var(tokenizer.read().text.replace(".", "_"))
|
| 316 |
+
elif tokenizer.check("QUOTED_STRING"):
|
| 317 |
+
return process_python_str(tokenizer.read().text)
|
| 318 |
+
else:
|
| 319 |
+
tokenizer.raise_syntax_error(
|
| 320 |
+
message="Expected a marker variable or quoted string"
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def process_env_var(env_var: str) -> Variable:
|
| 325 |
+
if env_var in ("platform_python_implementation", "python_implementation"):
|
| 326 |
+
return Variable("platform_python_implementation")
|
| 327 |
+
else:
|
| 328 |
+
return Variable(env_var)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def process_python_str(python_str: str) -> Value:
|
| 332 |
+
value = ast.literal_eval(python_str)
|
| 333 |
+
return Value(str(value))
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def _parse_marker_op(tokenizer: Tokenizer) -> Op:
|
| 337 |
+
"""
|
| 338 |
+
marker_op = IN | NOT IN | OP
|
| 339 |
+
"""
|
| 340 |
+
if tokenizer.check("IN"):
|
| 341 |
+
tokenizer.read()
|
| 342 |
+
return Op("in")
|
| 343 |
+
elif tokenizer.check("NOT"):
|
| 344 |
+
tokenizer.read()
|
| 345 |
+
tokenizer.expect("WS", expected="whitespace after 'not'")
|
| 346 |
+
tokenizer.expect("IN", expected="'in' after 'not'")
|
| 347 |
+
return Op("not in")
|
| 348 |
+
elif tokenizer.check("OP"):
|
| 349 |
+
return Op(tokenizer.read().text)
|
| 350 |
+
else:
|
| 351 |
+
return tokenizer.raise_syntax_error(
|
| 352 |
+
"Expected marker operator, one of "
|
| 353 |
+
"<=, <, !=, ==, >=, >, ~=, ===, in, not in"
|
| 354 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/markers.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is dual licensed under the terms of the Apache License, Version
|
| 2 |
+
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
| 3 |
+
# for complete details.
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import operator
|
| 8 |
+
import os
|
| 9 |
+
import platform
|
| 10 |
+
import sys
|
| 11 |
+
from typing import Any, Callable, TypedDict, cast
|
| 12 |
+
|
| 13 |
+
from ._parser import MarkerAtom, MarkerList, Op, Value, Variable
|
| 14 |
+
from ._parser import parse_marker as _parse_marker
|
| 15 |
+
from ._tokenizer import ParserSyntaxError
|
| 16 |
+
from .specifiers import InvalidSpecifier, Specifier
|
| 17 |
+
from .utils import canonicalize_name
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"InvalidMarker",
|
| 21 |
+
"Marker",
|
| 22 |
+
"UndefinedComparison",
|
| 23 |
+
"UndefinedEnvironmentName",
|
| 24 |
+
"default_environment",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
Operator = Callable[[str, str], bool]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class InvalidMarker(ValueError):
|
| 31 |
+
"""
|
| 32 |
+
An invalid marker was found, users should refer to PEP 508.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class UndefinedComparison(ValueError):
|
| 37 |
+
"""
|
| 38 |
+
An invalid operation was attempted on a value that doesn't support it.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class UndefinedEnvironmentName(ValueError):
|
| 43 |
+
"""
|
| 44 |
+
A name was attempted to be used that does not exist inside of the
|
| 45 |
+
environment.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Environment(TypedDict):
|
| 50 |
+
implementation_name: str
|
| 51 |
+
"""The implementation's identifier, e.g. ``'cpython'``."""
|
| 52 |
+
|
| 53 |
+
implementation_version: str
|
| 54 |
+
"""
|
| 55 |
+
The implementation's version, e.g. ``'3.13.0a2'`` for CPython 3.13.0a2, or
|
| 56 |
+
``'7.3.13'`` for PyPy3.10 v7.3.13.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
os_name: str
|
| 60 |
+
"""
|
| 61 |
+
The value of :py:data:`os.name`. The name of the operating system dependent module
|
| 62 |
+
imported, e.g. ``'posix'``.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
platform_machine: str
|
| 66 |
+
"""
|
| 67 |
+
Returns the machine type, e.g. ``'i386'``.
|
| 68 |
+
|
| 69 |
+
An empty string if the value cannot be determined.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
platform_release: str
|
| 73 |
+
"""
|
| 74 |
+
The system's release, e.g. ``'2.2.0'`` or ``'NT'``.
|
| 75 |
+
|
| 76 |
+
An empty string if the value cannot be determined.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
platform_system: str
|
| 80 |
+
"""
|
| 81 |
+
The system/OS name, e.g. ``'Linux'``, ``'Windows'`` or ``'Java'``.
|
| 82 |
+
|
| 83 |
+
An empty string if the value cannot be determined.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
platform_version: str
|
| 87 |
+
"""
|
| 88 |
+
The system's release version, e.g. ``'#3 on degas'``.
|
| 89 |
+
|
| 90 |
+
An empty string if the value cannot be determined.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
python_full_version: str
|
| 94 |
+
"""
|
| 95 |
+
The Python version as string ``'major.minor.patchlevel'``.
|
| 96 |
+
|
| 97 |
+
Note that unlike the Python :py:data:`sys.version`, this value will always include
|
| 98 |
+
the patchlevel (it defaults to 0).
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
platform_python_implementation: str
|
| 102 |
+
"""
|
| 103 |
+
A string identifying the Python implementation, e.g. ``'CPython'``.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
python_version: str
|
| 107 |
+
"""The Python version as string ``'major.minor'``."""
|
| 108 |
+
|
| 109 |
+
sys_platform: str
|
| 110 |
+
"""
|
| 111 |
+
This string contains a platform identifier that can be used to append
|
| 112 |
+
platform-specific components to :py:data:`sys.path`, for instance.
|
| 113 |
+
|
| 114 |
+
For Unix systems, except on Linux and AIX, this is the lowercased OS name as
|
| 115 |
+
returned by ``uname -s`` with the first part of the version as returned by
|
| 116 |
+
``uname -r`` appended, e.g. ``'sunos5'`` or ``'freebsd8'``, at the time when Python
|
| 117 |
+
was built.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _normalize_extra_values(results: Any) -> Any:
|
| 122 |
+
"""
|
| 123 |
+
Normalize extra values.
|
| 124 |
+
"""
|
| 125 |
+
if isinstance(results[0], tuple):
|
| 126 |
+
lhs, op, rhs = results[0]
|
| 127 |
+
if isinstance(lhs, Variable) and lhs.value == "extra":
|
| 128 |
+
normalized_extra = canonicalize_name(rhs.value)
|
| 129 |
+
rhs = Value(normalized_extra)
|
| 130 |
+
elif isinstance(rhs, Variable) and rhs.value == "extra":
|
| 131 |
+
normalized_extra = canonicalize_name(lhs.value)
|
| 132 |
+
lhs = Value(normalized_extra)
|
| 133 |
+
results[0] = lhs, op, rhs
|
| 134 |
+
return results
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _format_marker(
|
| 138 |
+
marker: list[str] | MarkerAtom | str, first: bool | None = True
|
| 139 |
+
) -> str:
|
| 140 |
+
assert isinstance(marker, (list, tuple, str))
|
| 141 |
+
|
| 142 |
+
# Sometimes we have a structure like [[...]] which is a single item list
|
| 143 |
+
# where the single item is itself it's own list. In that case we want skip
|
| 144 |
+
# the rest of this function so that we don't get extraneous () on the
|
| 145 |
+
# outside.
|
| 146 |
+
if (
|
| 147 |
+
isinstance(marker, list)
|
| 148 |
+
and len(marker) == 1
|
| 149 |
+
and isinstance(marker[0], (list, tuple))
|
| 150 |
+
):
|
| 151 |
+
return _format_marker(marker[0])
|
| 152 |
+
|
| 153 |
+
if isinstance(marker, list):
|
| 154 |
+
inner = (_format_marker(m, first=False) for m in marker)
|
| 155 |
+
if first:
|
| 156 |
+
return " ".join(inner)
|
| 157 |
+
else:
|
| 158 |
+
return "(" + " ".join(inner) + ")"
|
| 159 |
+
elif isinstance(marker, tuple):
|
| 160 |
+
return " ".join([m.serialize() for m in marker])
|
| 161 |
+
else:
|
| 162 |
+
return marker
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
_operators: dict[str, Operator] = {
|
| 166 |
+
"in": lambda lhs, rhs: lhs in rhs,
|
| 167 |
+
"not in": lambda lhs, rhs: lhs not in rhs,
|
| 168 |
+
"<": operator.lt,
|
| 169 |
+
"<=": operator.le,
|
| 170 |
+
"==": operator.eq,
|
| 171 |
+
"!=": operator.ne,
|
| 172 |
+
">=": operator.ge,
|
| 173 |
+
">": operator.gt,
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _eval_op(lhs: str, op: Op, rhs: str) -> bool:
|
| 178 |
+
try:
|
| 179 |
+
spec = Specifier("".join([op.serialize(), rhs]))
|
| 180 |
+
except InvalidSpecifier:
|
| 181 |
+
pass
|
| 182 |
+
else:
|
| 183 |
+
return spec.contains(lhs, prereleases=True)
|
| 184 |
+
|
| 185 |
+
oper: Operator | None = _operators.get(op.serialize())
|
| 186 |
+
if oper is None:
|
| 187 |
+
raise UndefinedComparison(f"Undefined {op!r} on {lhs!r} and {rhs!r}.")
|
| 188 |
+
|
| 189 |
+
return oper(lhs, rhs)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _normalize(*values: str, key: str) -> tuple[str, ...]:
|
| 193 |
+
# PEP 685 – Comparison of extra names for optional distribution dependencies
|
| 194 |
+
# https://peps.python.org/pep-0685/
|
| 195 |
+
# > When comparing extra names, tools MUST normalize the names being
|
| 196 |
+
# > compared using the semantics outlined in PEP 503 for names
|
| 197 |
+
if key == "extra":
|
| 198 |
+
return tuple(canonicalize_name(v) for v in values)
|
| 199 |
+
|
| 200 |
+
# other environment markers don't have such standards
|
| 201 |
+
return values
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _evaluate_markers(markers: MarkerList, environment: dict[str, str]) -> bool:
|
| 205 |
+
groups: list[list[bool]] = [[]]
|
| 206 |
+
|
| 207 |
+
for marker in markers:
|
| 208 |
+
assert isinstance(marker, (list, tuple, str))
|
| 209 |
+
|
| 210 |
+
if isinstance(marker, list):
|
| 211 |
+
groups[-1].append(_evaluate_markers(marker, environment))
|
| 212 |
+
elif isinstance(marker, tuple):
|
| 213 |
+
lhs, op, rhs = marker
|
| 214 |
+
|
| 215 |
+
if isinstance(lhs, Variable):
|
| 216 |
+
environment_key = lhs.value
|
| 217 |
+
lhs_value = environment[environment_key]
|
| 218 |
+
rhs_value = rhs.value
|
| 219 |
+
else:
|
| 220 |
+
lhs_value = lhs.value
|
| 221 |
+
environment_key = rhs.value
|
| 222 |
+
rhs_value = environment[environment_key]
|
| 223 |
+
|
| 224 |
+
lhs_value, rhs_value = _normalize(lhs_value, rhs_value, key=environment_key)
|
| 225 |
+
groups[-1].append(_eval_op(lhs_value, op, rhs_value))
|
| 226 |
+
else:
|
| 227 |
+
assert marker in ["and", "or"]
|
| 228 |
+
if marker == "or":
|
| 229 |
+
groups.append([])
|
| 230 |
+
|
| 231 |
+
return any(all(item) for item in groups)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def format_full_version(info: sys._version_info) -> str:
|
| 235 |
+
version = f"{info.major}.{info.minor}.{info.micro}"
|
| 236 |
+
kind = info.releaselevel
|
| 237 |
+
if kind != "final":
|
| 238 |
+
version += kind[0] + str(info.serial)
|
| 239 |
+
return version
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def default_environment() -> Environment:
|
| 243 |
+
iver = format_full_version(sys.implementation.version)
|
| 244 |
+
implementation_name = sys.implementation.name
|
| 245 |
+
return {
|
| 246 |
+
"implementation_name": implementation_name,
|
| 247 |
+
"implementation_version": iver,
|
| 248 |
+
"os_name": os.name,
|
| 249 |
+
"platform_machine": platform.machine(),
|
| 250 |
+
"platform_release": platform.release(),
|
| 251 |
+
"platform_system": platform.system(),
|
| 252 |
+
"platform_version": platform.version(),
|
| 253 |
+
"python_full_version": platform.python_version(),
|
| 254 |
+
"platform_python_implementation": platform.python_implementation(),
|
| 255 |
+
"python_version": ".".join(platform.python_version_tuple()[:2]),
|
| 256 |
+
"sys_platform": sys.platform,
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class Marker:
|
| 261 |
+
def __init__(self, marker: str) -> None:
|
| 262 |
+
# Note: We create a Marker object without calling this constructor in
|
| 263 |
+
# packaging.requirements.Requirement. If any additional logic is
|
| 264 |
+
# added here, make sure to mirror/adapt Requirement.
|
| 265 |
+
try:
|
| 266 |
+
self._markers = _normalize_extra_values(_parse_marker(marker))
|
| 267 |
+
# The attribute `_markers` can be described in terms of a recursive type:
|
| 268 |
+
# MarkerList = List[Union[Tuple[Node, ...], str, MarkerList]]
|
| 269 |
+
#
|
| 270 |
+
# For example, the following expression:
|
| 271 |
+
# python_version > "3.6" or (python_version == "3.6" and os_name == "unix")
|
| 272 |
+
#
|
| 273 |
+
# is parsed into:
|
| 274 |
+
# [
|
| 275 |
+
# (<Variable('python_version')>, <Op('>')>, <Value('3.6')>),
|
| 276 |
+
# 'and',
|
| 277 |
+
# [
|
| 278 |
+
# (<Variable('python_version')>, <Op('==')>, <Value('3.6')>),
|
| 279 |
+
# 'or',
|
| 280 |
+
# (<Variable('os_name')>, <Op('==')>, <Value('unix')>)
|
| 281 |
+
# ]
|
| 282 |
+
# ]
|
| 283 |
+
except ParserSyntaxError as e:
|
| 284 |
+
raise InvalidMarker(str(e)) from e
|
| 285 |
+
|
| 286 |
+
def __str__(self) -> str:
|
| 287 |
+
return _format_marker(self._markers)
|
| 288 |
+
|
| 289 |
+
def __repr__(self) -> str:
|
| 290 |
+
return f"<Marker('{self}')>"
|
| 291 |
+
|
| 292 |
+
def __hash__(self) -> int:
|
| 293 |
+
return hash((self.__class__.__name__, str(self)))
|
| 294 |
+
|
| 295 |
+
def __eq__(self, other: Any) -> bool:
|
| 296 |
+
if not isinstance(other, Marker):
|
| 297 |
+
return NotImplemented
|
| 298 |
+
|
| 299 |
+
return str(self) == str(other)
|
| 300 |
+
|
| 301 |
+
def evaluate(self, environment: dict[str, str] | None = None) -> bool:
|
| 302 |
+
"""Evaluate a marker.
|
| 303 |
+
|
| 304 |
+
Return the boolean from evaluating the given marker against the
|
| 305 |
+
environment. environment is an optional argument to override all or
|
| 306 |
+
part of the determined environment.
|
| 307 |
+
|
| 308 |
+
The environment is determined from the current Python process.
|
| 309 |
+
"""
|
| 310 |
+
current_environment = cast("dict[str, str]", default_environment())
|
| 311 |
+
current_environment["extra"] = ""
|
| 312 |
+
if environment is not None:
|
| 313 |
+
current_environment.update(environment)
|
| 314 |
+
# The API used to allow setting extra to None. We need to handle this
|
| 315 |
+
# case for backwards compatibility.
|
| 316 |
+
if current_environment["extra"] is None:
|
| 317 |
+
current_environment["extra"] = ""
|
| 318 |
+
|
| 319 |
+
return _evaluate_markers(
|
| 320 |
+
self._markers, _repair_python_full_version(current_environment)
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _repair_python_full_version(env: dict[str, str]) -> dict[str, str]:
|
| 325 |
+
"""
|
| 326 |
+
Work around platform.python_version() returning something that is not PEP 440
|
| 327 |
+
compliant for non-tagged Python builds.
|
| 328 |
+
"""
|
| 329 |
+
if env["python_full_version"].endswith("+"):
|
| 330 |
+
env["python_full_version"] += "local"
|
| 331 |
+
return env
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/metadata.py
ADDED
|
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import email.feedparser
|
| 4 |
+
import email.header
|
| 5 |
+
import email.message
|
| 6 |
+
import email.parser
|
| 7 |
+
import email.policy
|
| 8 |
+
import pathlib
|
| 9 |
+
import sys
|
| 10 |
+
import typing
|
| 11 |
+
from typing import (
|
| 12 |
+
Any,
|
| 13 |
+
Callable,
|
| 14 |
+
Generic,
|
| 15 |
+
Literal,
|
| 16 |
+
TypedDict,
|
| 17 |
+
cast,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from . import licenses, requirements, specifiers, utils
|
| 21 |
+
from . import version as version_module
|
| 22 |
+
from .licenses import NormalizedLicenseExpression
|
| 23 |
+
|
| 24 |
+
T = typing.TypeVar("T")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if sys.version_info >= (3, 11): # pragma: no cover
|
| 28 |
+
ExceptionGroup = ExceptionGroup
|
| 29 |
+
else: # pragma: no cover
|
| 30 |
+
|
| 31 |
+
class ExceptionGroup(Exception):
|
| 32 |
+
"""A minimal implementation of :external:exc:`ExceptionGroup` from Python 3.11.
|
| 33 |
+
|
| 34 |
+
If :external:exc:`ExceptionGroup` is already defined by Python itself,
|
| 35 |
+
that version is used instead.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
message: str
|
| 39 |
+
exceptions: list[Exception]
|
| 40 |
+
|
| 41 |
+
def __init__(self, message: str, exceptions: list[Exception]) -> None:
|
| 42 |
+
self.message = message
|
| 43 |
+
self.exceptions = exceptions
|
| 44 |
+
|
| 45 |
+
def __repr__(self) -> str:
|
| 46 |
+
return f"{self.__class__.__name__}({self.message!r}, {self.exceptions!r})"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class InvalidMetadata(ValueError):
|
| 50 |
+
"""A metadata field contains invalid data."""
|
| 51 |
+
|
| 52 |
+
field: str
|
| 53 |
+
"""The name of the field that contains invalid data."""
|
| 54 |
+
|
| 55 |
+
def __init__(self, field: str, message: str) -> None:
|
| 56 |
+
self.field = field
|
| 57 |
+
super().__init__(message)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# The RawMetadata class attempts to make as few assumptions about the underlying
|
| 61 |
+
# serialization formats as possible. The idea is that as long as a serialization
|
| 62 |
+
# formats offer some very basic primitives in *some* way then we can support
|
| 63 |
+
# serializing to and from that format.
|
| 64 |
+
class RawMetadata(TypedDict, total=False):
|
| 65 |
+
"""A dictionary of raw core metadata.
|
| 66 |
+
|
| 67 |
+
Each field in core metadata maps to a key of this dictionary (when data is
|
| 68 |
+
provided). The key is lower-case and underscores are used instead of dashes
|
| 69 |
+
compared to the equivalent core metadata field. Any core metadata field that
|
| 70 |
+
can be specified multiple times or can hold multiple values in a single
|
| 71 |
+
field have a key with a plural name. See :class:`Metadata` whose attributes
|
| 72 |
+
match the keys of this dictionary.
|
| 73 |
+
|
| 74 |
+
Core metadata fields that can be specified multiple times are stored as a
|
| 75 |
+
list or dict depending on which is appropriate for the field. Any fields
|
| 76 |
+
which hold multiple values in a single field are stored as a list.
|
| 77 |
+
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
# Metadata 1.0 - PEP 241
|
| 81 |
+
metadata_version: str
|
| 82 |
+
name: str
|
| 83 |
+
version: str
|
| 84 |
+
platforms: list[str]
|
| 85 |
+
summary: str
|
| 86 |
+
description: str
|
| 87 |
+
keywords: list[str]
|
| 88 |
+
home_page: str
|
| 89 |
+
author: str
|
| 90 |
+
author_email: str
|
| 91 |
+
license: str
|
| 92 |
+
|
| 93 |
+
# Metadata 1.1 - PEP 314
|
| 94 |
+
supported_platforms: list[str]
|
| 95 |
+
download_url: str
|
| 96 |
+
classifiers: list[str]
|
| 97 |
+
requires: list[str]
|
| 98 |
+
provides: list[str]
|
| 99 |
+
obsoletes: list[str]
|
| 100 |
+
|
| 101 |
+
# Metadata 1.2 - PEP 345
|
| 102 |
+
maintainer: str
|
| 103 |
+
maintainer_email: str
|
| 104 |
+
requires_dist: list[str]
|
| 105 |
+
provides_dist: list[str]
|
| 106 |
+
obsoletes_dist: list[str]
|
| 107 |
+
requires_python: str
|
| 108 |
+
requires_external: list[str]
|
| 109 |
+
project_urls: dict[str, str]
|
| 110 |
+
|
| 111 |
+
# Metadata 2.0
|
| 112 |
+
# PEP 426 attempted to completely revamp the metadata format
|
| 113 |
+
# but got stuck without ever being able to build consensus on
|
| 114 |
+
# it and ultimately ended up withdrawn.
|
| 115 |
+
#
|
| 116 |
+
# However, a number of tools had started emitting METADATA with
|
| 117 |
+
# `2.0` Metadata-Version, so for historical reasons, this version
|
| 118 |
+
# was skipped.
|
| 119 |
+
|
| 120 |
+
# Metadata 2.1 - PEP 566
|
| 121 |
+
description_content_type: str
|
| 122 |
+
provides_extra: list[str]
|
| 123 |
+
|
| 124 |
+
# Metadata 2.2 - PEP 643
|
| 125 |
+
dynamic: list[str]
|
| 126 |
+
|
| 127 |
+
# Metadata 2.3 - PEP 685
|
| 128 |
+
# No new fields were added in PEP 685, just some edge case were
|
| 129 |
+
# tightened up to provide better interoptability.
|
| 130 |
+
|
| 131 |
+
# Metadata 2.4 - PEP 639
|
| 132 |
+
license_expression: str
|
| 133 |
+
license_files: list[str]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
_STRING_FIELDS = {
|
| 137 |
+
"author",
|
| 138 |
+
"author_email",
|
| 139 |
+
"description",
|
| 140 |
+
"description_content_type",
|
| 141 |
+
"download_url",
|
| 142 |
+
"home_page",
|
| 143 |
+
"license",
|
| 144 |
+
"license_expression",
|
| 145 |
+
"maintainer",
|
| 146 |
+
"maintainer_email",
|
| 147 |
+
"metadata_version",
|
| 148 |
+
"name",
|
| 149 |
+
"requires_python",
|
| 150 |
+
"summary",
|
| 151 |
+
"version",
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
_LIST_FIELDS = {
|
| 155 |
+
"classifiers",
|
| 156 |
+
"dynamic",
|
| 157 |
+
"license_files",
|
| 158 |
+
"obsoletes",
|
| 159 |
+
"obsoletes_dist",
|
| 160 |
+
"platforms",
|
| 161 |
+
"provides",
|
| 162 |
+
"provides_dist",
|
| 163 |
+
"provides_extra",
|
| 164 |
+
"requires",
|
| 165 |
+
"requires_dist",
|
| 166 |
+
"requires_external",
|
| 167 |
+
"supported_platforms",
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
_DICT_FIELDS = {
|
| 171 |
+
"project_urls",
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _parse_keywords(data: str) -> list[str]:
|
| 176 |
+
"""Split a string of comma-separated keywords into a list of keywords."""
|
| 177 |
+
return [k.strip() for k in data.split(",")]
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _parse_project_urls(data: list[str]) -> dict[str, str]:
|
| 181 |
+
"""Parse a list of label/URL string pairings separated by a comma."""
|
| 182 |
+
urls = {}
|
| 183 |
+
for pair in data:
|
| 184 |
+
# Our logic is slightly tricky here as we want to try and do
|
| 185 |
+
# *something* reasonable with malformed data.
|
| 186 |
+
#
|
| 187 |
+
# The main thing that we have to worry about, is data that does
|
| 188 |
+
# not have a ',' at all to split the label from the Value. There
|
| 189 |
+
# isn't a singular right answer here, and we will fail validation
|
| 190 |
+
# later on (if the caller is validating) so it doesn't *really*
|
| 191 |
+
# matter, but since the missing value has to be an empty str
|
| 192 |
+
# and our return value is dict[str, str], if we let the key
|
| 193 |
+
# be the missing value, then they'd have multiple '' values that
|
| 194 |
+
# overwrite each other in a accumulating dict.
|
| 195 |
+
#
|
| 196 |
+
# The other potentional issue is that it's possible to have the
|
| 197 |
+
# same label multiple times in the metadata, with no solid "right"
|
| 198 |
+
# answer with what to do in that case. As such, we'll do the only
|
| 199 |
+
# thing we can, which is treat the field as unparseable and add it
|
| 200 |
+
# to our list of unparsed fields.
|
| 201 |
+
parts = [p.strip() for p in pair.split(",", 1)]
|
| 202 |
+
parts.extend([""] * (max(0, 2 - len(parts)))) # Ensure 2 items
|
| 203 |
+
|
| 204 |
+
# TODO: The spec doesn't say anything about if the keys should be
|
| 205 |
+
# considered case sensitive or not... logically they should
|
| 206 |
+
# be case-preserving and case-insensitive, but doing that
|
| 207 |
+
# would open up more cases where we might have duplicate
|
| 208 |
+
# entries.
|
| 209 |
+
label, url = parts
|
| 210 |
+
if label in urls:
|
| 211 |
+
# The label already exists in our set of urls, so this field
|
| 212 |
+
# is unparseable, and we can just add the whole thing to our
|
| 213 |
+
# unparseable data and stop processing it.
|
| 214 |
+
raise KeyError("duplicate labels in project urls")
|
| 215 |
+
urls[label] = url
|
| 216 |
+
|
| 217 |
+
return urls
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _get_payload(msg: email.message.Message, source: bytes | str) -> str:
|
| 221 |
+
"""Get the body of the message."""
|
| 222 |
+
# If our source is a str, then our caller has managed encodings for us,
|
| 223 |
+
# and we don't need to deal with it.
|
| 224 |
+
if isinstance(source, str):
|
| 225 |
+
payload = msg.get_payload()
|
| 226 |
+
assert isinstance(payload, str)
|
| 227 |
+
return payload
|
| 228 |
+
# If our source is a bytes, then we're managing the encoding and we need
|
| 229 |
+
# to deal with it.
|
| 230 |
+
else:
|
| 231 |
+
bpayload = msg.get_payload(decode=True)
|
| 232 |
+
assert isinstance(bpayload, bytes)
|
| 233 |
+
try:
|
| 234 |
+
return bpayload.decode("utf8", "strict")
|
| 235 |
+
except UnicodeDecodeError as exc:
|
| 236 |
+
raise ValueError("payload in an invalid encoding") from exc
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# The various parse_FORMAT functions here are intended to be as lenient as
|
| 240 |
+
# possible in their parsing, while still returning a correctly typed
|
| 241 |
+
# RawMetadata.
|
| 242 |
+
#
|
| 243 |
+
# To aid in this, we also generally want to do as little touching of the
|
| 244 |
+
# data as possible, except where there are possibly some historic holdovers
|
| 245 |
+
# that make valid data awkward to work with.
|
| 246 |
+
#
|
| 247 |
+
# While this is a lower level, intermediate format than our ``Metadata``
|
| 248 |
+
# class, some light touch ups can make a massive difference in usability.
|
| 249 |
+
|
| 250 |
+
# Map METADATA fields to RawMetadata.
|
| 251 |
+
_EMAIL_TO_RAW_MAPPING = {
|
| 252 |
+
"author": "author",
|
| 253 |
+
"author-email": "author_email",
|
| 254 |
+
"classifier": "classifiers",
|
| 255 |
+
"description": "description",
|
| 256 |
+
"description-content-type": "description_content_type",
|
| 257 |
+
"download-url": "download_url",
|
| 258 |
+
"dynamic": "dynamic",
|
| 259 |
+
"home-page": "home_page",
|
| 260 |
+
"keywords": "keywords",
|
| 261 |
+
"license": "license",
|
| 262 |
+
"license-expression": "license_expression",
|
| 263 |
+
"license-file": "license_files",
|
| 264 |
+
"maintainer": "maintainer",
|
| 265 |
+
"maintainer-email": "maintainer_email",
|
| 266 |
+
"metadata-version": "metadata_version",
|
| 267 |
+
"name": "name",
|
| 268 |
+
"obsoletes": "obsoletes",
|
| 269 |
+
"obsoletes-dist": "obsoletes_dist",
|
| 270 |
+
"platform": "platforms",
|
| 271 |
+
"project-url": "project_urls",
|
| 272 |
+
"provides": "provides",
|
| 273 |
+
"provides-dist": "provides_dist",
|
| 274 |
+
"provides-extra": "provides_extra",
|
| 275 |
+
"requires": "requires",
|
| 276 |
+
"requires-dist": "requires_dist",
|
| 277 |
+
"requires-external": "requires_external",
|
| 278 |
+
"requires-python": "requires_python",
|
| 279 |
+
"summary": "summary",
|
| 280 |
+
"supported-platform": "supported_platforms",
|
| 281 |
+
"version": "version",
|
| 282 |
+
}
|
| 283 |
+
_RAW_TO_EMAIL_MAPPING = {raw: email for email, raw in _EMAIL_TO_RAW_MAPPING.items()}
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def parse_email(data: bytes | str) -> tuple[RawMetadata, dict[str, list[str]]]:
|
| 287 |
+
"""Parse a distribution's metadata stored as email headers (e.g. from ``METADATA``).
|
| 288 |
+
|
| 289 |
+
This function returns a two-item tuple of dicts. The first dict is of
|
| 290 |
+
recognized fields from the core metadata specification. Fields that can be
|
| 291 |
+
parsed and translated into Python's built-in types are converted
|
| 292 |
+
appropriately. All other fields are left as-is. Fields that are allowed to
|
| 293 |
+
appear multiple times are stored as lists.
|
| 294 |
+
|
| 295 |
+
The second dict contains all other fields from the metadata. This includes
|
| 296 |
+
any unrecognized fields. It also includes any fields which are expected to
|
| 297 |
+
be parsed into a built-in type but were not formatted appropriately. Finally,
|
| 298 |
+
any fields that are expected to appear only once but are repeated are
|
| 299 |
+
included in this dict.
|
| 300 |
+
|
| 301 |
+
"""
|
| 302 |
+
raw: dict[str, str | list[str] | dict[str, str]] = {}
|
| 303 |
+
unparsed: dict[str, list[str]] = {}
|
| 304 |
+
|
| 305 |
+
if isinstance(data, str):
|
| 306 |
+
parsed = email.parser.Parser(policy=email.policy.compat32).parsestr(data)
|
| 307 |
+
else:
|
| 308 |
+
parsed = email.parser.BytesParser(policy=email.policy.compat32).parsebytes(data)
|
| 309 |
+
|
| 310 |
+
# We have to wrap parsed.keys() in a set, because in the case of multiple
|
| 311 |
+
# values for a key (a list), the key will appear multiple times in the
|
| 312 |
+
# list of keys, but we're avoiding that by using get_all().
|
| 313 |
+
for name in frozenset(parsed.keys()):
|
| 314 |
+
# Header names in RFC are case insensitive, so we'll normalize to all
|
| 315 |
+
# lower case to make comparisons easier.
|
| 316 |
+
name = name.lower()
|
| 317 |
+
|
| 318 |
+
# We use get_all() here, even for fields that aren't multiple use,
|
| 319 |
+
# because otherwise someone could have e.g. two Name fields, and we
|
| 320 |
+
# would just silently ignore it rather than doing something about it.
|
| 321 |
+
headers = parsed.get_all(name) or []
|
| 322 |
+
|
| 323 |
+
# The way the email module works when parsing bytes is that it
|
| 324 |
+
# unconditionally decodes the bytes as ascii using the surrogateescape
|
| 325 |
+
# handler. When you pull that data back out (such as with get_all() ),
|
| 326 |
+
# it looks to see if the str has any surrogate escapes, and if it does
|
| 327 |
+
# it wraps it in a Header object instead of returning the string.
|
| 328 |
+
#
|
| 329 |
+
# As such, we'll look for those Header objects, and fix up the encoding.
|
| 330 |
+
value = []
|
| 331 |
+
# Flag if we have run into any issues processing the headers, thus
|
| 332 |
+
# signalling that the data belongs in 'unparsed'.
|
| 333 |
+
valid_encoding = True
|
| 334 |
+
for h in headers:
|
| 335 |
+
# It's unclear if this can return more types than just a Header or
|
| 336 |
+
# a str, so we'll just assert here to make sure.
|
| 337 |
+
assert isinstance(h, (email.header.Header, str))
|
| 338 |
+
|
| 339 |
+
# If it's a header object, we need to do our little dance to get
|
| 340 |
+
# the real data out of it. In cases where there is invalid data
|
| 341 |
+
# we're going to end up with mojibake, but there's no obvious, good
|
| 342 |
+
# way around that without reimplementing parts of the Header object
|
| 343 |
+
# ourselves.
|
| 344 |
+
#
|
| 345 |
+
# That should be fine since, if mojibacked happens, this key is
|
| 346 |
+
# going into the unparsed dict anyways.
|
| 347 |
+
if isinstance(h, email.header.Header):
|
| 348 |
+
# The Header object stores it's data as chunks, and each chunk
|
| 349 |
+
# can be independently encoded, so we'll need to check each
|
| 350 |
+
# of them.
|
| 351 |
+
chunks: list[tuple[bytes, str | None]] = []
|
| 352 |
+
for bin, encoding in email.header.decode_header(h):
|
| 353 |
+
try:
|
| 354 |
+
bin.decode("utf8", "strict")
|
| 355 |
+
except UnicodeDecodeError:
|
| 356 |
+
# Enable mojibake.
|
| 357 |
+
encoding = "latin1"
|
| 358 |
+
valid_encoding = False
|
| 359 |
+
else:
|
| 360 |
+
encoding = "utf8"
|
| 361 |
+
chunks.append((bin, encoding))
|
| 362 |
+
|
| 363 |
+
# Turn our chunks back into a Header object, then let that
|
| 364 |
+
# Header object do the right thing to turn them into a
|
| 365 |
+
# string for us.
|
| 366 |
+
value.append(str(email.header.make_header(chunks)))
|
| 367 |
+
# This is already a string, so just add it.
|
| 368 |
+
else:
|
| 369 |
+
value.append(h)
|
| 370 |
+
|
| 371 |
+
# We've processed all of our values to get them into a list of str,
|
| 372 |
+
# but we may have mojibake data, in which case this is an unparsed
|
| 373 |
+
# field.
|
| 374 |
+
if not valid_encoding:
|
| 375 |
+
unparsed[name] = value
|
| 376 |
+
continue
|
| 377 |
+
|
| 378 |
+
raw_name = _EMAIL_TO_RAW_MAPPING.get(name)
|
| 379 |
+
if raw_name is None:
|
| 380 |
+
# This is a bit of a weird situation, we've encountered a key that
|
| 381 |
+
# we don't know what it means, so we don't know whether it's meant
|
| 382 |
+
# to be a list or not.
|
| 383 |
+
#
|
| 384 |
+
# Since we can't really tell one way or another, we'll just leave it
|
| 385 |
+
# as a list, even though it may be a single item list, because that's
|
| 386 |
+
# what makes the most sense for email headers.
|
| 387 |
+
unparsed[name] = value
|
| 388 |
+
continue
|
| 389 |
+
|
| 390 |
+
# If this is one of our string fields, then we'll check to see if our
|
| 391 |
+
# value is a list of a single item. If it is then we'll assume that
|
| 392 |
+
# it was emitted as a single string, and unwrap the str from inside
|
| 393 |
+
# the list.
|
| 394 |
+
#
|
| 395 |
+
# If it's any other kind of data, then we haven't the faintest clue
|
| 396 |
+
# what we should parse it as, and we have to just add it to our list
|
| 397 |
+
# of unparsed stuff.
|
| 398 |
+
if raw_name in _STRING_FIELDS and len(value) == 1:
|
| 399 |
+
raw[raw_name] = value[0]
|
| 400 |
+
# If this is one of our list of string fields, then we can just assign
|
| 401 |
+
# the value, since email *only* has strings, and our get_all() call
|
| 402 |
+
# above ensures that this is a list.
|
| 403 |
+
elif raw_name in _LIST_FIELDS:
|
| 404 |
+
raw[raw_name] = value
|
| 405 |
+
# Special Case: Keywords
|
| 406 |
+
# The keywords field is implemented in the metadata spec as a str,
|
| 407 |
+
# but it conceptually is a list of strings, and is serialized using
|
| 408 |
+
# ", ".join(keywords), so we'll do some light data massaging to turn
|
| 409 |
+
# this into what it logically is.
|
| 410 |
+
elif raw_name == "keywords" and len(value) == 1:
|
| 411 |
+
raw[raw_name] = _parse_keywords(value[0])
|
| 412 |
+
# Special Case: Project-URL
|
| 413 |
+
# The project urls is implemented in the metadata spec as a list of
|
| 414 |
+
# specially-formatted strings that represent a key and a value, which
|
| 415 |
+
# is fundamentally a mapping, however the email format doesn't support
|
| 416 |
+
# mappings in a sane way, so it was crammed into a list of strings
|
| 417 |
+
# instead.
|
| 418 |
+
#
|
| 419 |
+
# We will do a little light data massaging to turn this into a map as
|
| 420 |
+
# it logically should be.
|
| 421 |
+
elif raw_name == "project_urls":
|
| 422 |
+
try:
|
| 423 |
+
raw[raw_name] = _parse_project_urls(value)
|
| 424 |
+
except KeyError:
|
| 425 |
+
unparsed[name] = value
|
| 426 |
+
# Nothing that we've done has managed to parse this, so it'll just
|
| 427 |
+
# throw it in our unparseable data and move on.
|
| 428 |
+
else:
|
| 429 |
+
unparsed[name] = value
|
| 430 |
+
|
| 431 |
+
# We need to support getting the Description from the message payload in
|
| 432 |
+
# addition to getting it from the the headers. This does mean, though, there
|
| 433 |
+
# is the possibility of it being set both ways, in which case we put both
|
| 434 |
+
# in 'unparsed' since we don't know which is right.
|
| 435 |
+
try:
|
| 436 |
+
payload = _get_payload(parsed, data)
|
| 437 |
+
except ValueError:
|
| 438 |
+
unparsed.setdefault("description", []).append(
|
| 439 |
+
parsed.get_payload(decode=isinstance(data, bytes)) # type: ignore[call-overload]
|
| 440 |
+
)
|
| 441 |
+
else:
|
| 442 |
+
if payload:
|
| 443 |
+
# Check to see if we've already got a description, if so then both
|
| 444 |
+
# it, and this body move to unparseable.
|
| 445 |
+
if "description" in raw:
|
| 446 |
+
description_header = cast(str, raw.pop("description"))
|
| 447 |
+
unparsed.setdefault("description", []).extend(
|
| 448 |
+
[description_header, payload]
|
| 449 |
+
)
|
| 450 |
+
elif "description" in unparsed:
|
| 451 |
+
unparsed["description"].append(payload)
|
| 452 |
+
else:
|
| 453 |
+
raw["description"] = payload
|
| 454 |
+
|
| 455 |
+
# We need to cast our `raw` to a metadata, because a TypedDict only support
|
| 456 |
+
# literal key names, but we're computing our key names on purpose, but the
|
| 457 |
+
# way this function is implemented, our `TypedDict` can only have valid key
|
| 458 |
+
# names.
|
| 459 |
+
return cast(RawMetadata, raw), unparsed
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
_NOT_FOUND = object()
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
# Keep the two values in sync.
|
| 466 |
+
_VALID_METADATA_VERSIONS = ["1.0", "1.1", "1.2", "2.1", "2.2", "2.3", "2.4"]
|
| 467 |
+
_MetadataVersion = Literal["1.0", "1.1", "1.2", "2.1", "2.2", "2.3", "2.4"]
|
| 468 |
+
|
| 469 |
+
_REQUIRED_ATTRS = frozenset(["metadata_version", "name", "version"])
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class _Validator(Generic[T]):
|
| 473 |
+
"""Validate a metadata field.
|
| 474 |
+
|
| 475 |
+
All _process_*() methods correspond to a core metadata field. The method is
|
| 476 |
+
called with the field's raw value. If the raw value is valid it is returned
|
| 477 |
+
in its "enriched" form (e.g. ``version.Version`` for the ``Version`` field).
|
| 478 |
+
If the raw value is invalid, :exc:`InvalidMetadata` is raised (with a cause
|
| 479 |
+
as appropriate).
|
| 480 |
+
"""
|
| 481 |
+
|
| 482 |
+
name: str
|
| 483 |
+
raw_name: str
|
| 484 |
+
added: _MetadataVersion
|
| 485 |
+
|
| 486 |
+
def __init__(
|
| 487 |
+
self,
|
| 488 |
+
*,
|
| 489 |
+
added: _MetadataVersion = "1.0",
|
| 490 |
+
) -> None:
|
| 491 |
+
self.added = added
|
| 492 |
+
|
| 493 |
+
def __set_name__(self, _owner: Metadata, name: str) -> None:
|
| 494 |
+
self.name = name
|
| 495 |
+
self.raw_name = _RAW_TO_EMAIL_MAPPING[name]
|
| 496 |
+
|
| 497 |
+
def __get__(self, instance: Metadata, _owner: type[Metadata]) -> T:
|
| 498 |
+
# With Python 3.8, the caching can be replaced with functools.cached_property().
|
| 499 |
+
# No need to check the cache as attribute lookup will resolve into the
|
| 500 |
+
# instance's __dict__ before __get__ is called.
|
| 501 |
+
cache = instance.__dict__
|
| 502 |
+
value = instance._raw.get(self.name)
|
| 503 |
+
|
| 504 |
+
# To make the _process_* methods easier, we'll check if the value is None
|
| 505 |
+
# and if this field is NOT a required attribute, and if both of those
|
| 506 |
+
# things are true, we'll skip the the converter. This will mean that the
|
| 507 |
+
# converters never have to deal with the None union.
|
| 508 |
+
if self.name in _REQUIRED_ATTRS or value is not None:
|
| 509 |
+
try:
|
| 510 |
+
converter: Callable[[Any], T] = getattr(self, f"_process_{self.name}")
|
| 511 |
+
except AttributeError:
|
| 512 |
+
pass
|
| 513 |
+
else:
|
| 514 |
+
value = converter(value)
|
| 515 |
+
|
| 516 |
+
cache[self.name] = value
|
| 517 |
+
try:
|
| 518 |
+
del instance._raw[self.name] # type: ignore[misc]
|
| 519 |
+
except KeyError:
|
| 520 |
+
pass
|
| 521 |
+
|
| 522 |
+
return cast(T, value)
|
| 523 |
+
|
| 524 |
+
def _invalid_metadata(
|
| 525 |
+
self, msg: str, cause: Exception | None = None
|
| 526 |
+
) -> InvalidMetadata:
|
| 527 |
+
exc = InvalidMetadata(
|
| 528 |
+
self.raw_name, msg.format_map({"field": repr(self.raw_name)})
|
| 529 |
+
)
|
| 530 |
+
exc.__cause__ = cause
|
| 531 |
+
return exc
|
| 532 |
+
|
| 533 |
+
def _process_metadata_version(self, value: str) -> _MetadataVersion:
|
| 534 |
+
# Implicitly makes Metadata-Version required.
|
| 535 |
+
if value not in _VALID_METADATA_VERSIONS:
|
| 536 |
+
raise self._invalid_metadata(f"{value!r} is not a valid metadata version")
|
| 537 |
+
return cast(_MetadataVersion, value)
|
| 538 |
+
|
| 539 |
+
def _process_name(self, value: str) -> str:
|
| 540 |
+
if not value:
|
| 541 |
+
raise self._invalid_metadata("{field} is a required field")
|
| 542 |
+
# Validate the name as a side-effect.
|
| 543 |
+
try:
|
| 544 |
+
utils.canonicalize_name(value, validate=True)
|
| 545 |
+
except utils.InvalidName as exc:
|
| 546 |
+
raise self._invalid_metadata(
|
| 547 |
+
f"{value!r} is invalid for {{field}}", cause=exc
|
| 548 |
+
) from exc
|
| 549 |
+
else:
|
| 550 |
+
return value
|
| 551 |
+
|
| 552 |
+
def _process_version(self, value: str) -> version_module.Version:
|
| 553 |
+
if not value:
|
| 554 |
+
raise self._invalid_metadata("{field} is a required field")
|
| 555 |
+
try:
|
| 556 |
+
return version_module.parse(value)
|
| 557 |
+
except version_module.InvalidVersion as exc:
|
| 558 |
+
raise self._invalid_metadata(
|
| 559 |
+
f"{value!r} is invalid for {{field}}", cause=exc
|
| 560 |
+
) from exc
|
| 561 |
+
|
| 562 |
+
def _process_summary(self, value: str) -> str:
|
| 563 |
+
"""Check the field contains no newlines."""
|
| 564 |
+
if "\n" in value:
|
| 565 |
+
raise self._invalid_metadata("{field} must be a single line")
|
| 566 |
+
return value
|
| 567 |
+
|
| 568 |
+
def _process_description_content_type(self, value: str) -> str:
|
| 569 |
+
content_types = {"text/plain", "text/x-rst", "text/markdown"}
|
| 570 |
+
message = email.message.EmailMessage()
|
| 571 |
+
message["content-type"] = value
|
| 572 |
+
|
| 573 |
+
content_type, parameters = (
|
| 574 |
+
# Defaults to `text/plain` if parsing failed.
|
| 575 |
+
message.get_content_type().lower(),
|
| 576 |
+
message["content-type"].params,
|
| 577 |
+
)
|
| 578 |
+
# Check if content-type is valid or defaulted to `text/plain` and thus was
|
| 579 |
+
# not parseable.
|
| 580 |
+
if content_type not in content_types or content_type not in value.lower():
|
| 581 |
+
raise self._invalid_metadata(
|
| 582 |
+
f"{{field}} must be one of {list(content_types)}, not {value!r}"
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
charset = parameters.get("charset", "UTF-8")
|
| 586 |
+
if charset != "UTF-8":
|
| 587 |
+
raise self._invalid_metadata(
|
| 588 |
+
f"{{field}} can only specify the UTF-8 charset, not {list(charset)}"
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
markdown_variants = {"GFM", "CommonMark"}
|
| 592 |
+
variant = parameters.get("variant", "GFM") # Use an acceptable default.
|
| 593 |
+
if content_type == "text/markdown" and variant not in markdown_variants:
|
| 594 |
+
raise self._invalid_metadata(
|
| 595 |
+
f"valid Markdown variants for {{field}} are {list(markdown_variants)}, "
|
| 596 |
+
f"not {variant!r}",
|
| 597 |
+
)
|
| 598 |
+
return value
|
| 599 |
+
|
| 600 |
+
def _process_dynamic(self, value: list[str]) -> list[str]:
|
| 601 |
+
for dynamic_field in map(str.lower, value):
|
| 602 |
+
if dynamic_field in {"name", "version", "metadata-version"}:
|
| 603 |
+
raise self._invalid_metadata(
|
| 604 |
+
f"{dynamic_field!r} is not allowed as a dynamic field"
|
| 605 |
+
)
|
| 606 |
+
elif dynamic_field not in _EMAIL_TO_RAW_MAPPING:
|
| 607 |
+
raise self._invalid_metadata(
|
| 608 |
+
f"{dynamic_field!r} is not a valid dynamic field"
|
| 609 |
+
)
|
| 610 |
+
return list(map(str.lower, value))
|
| 611 |
+
|
| 612 |
+
def _process_provides_extra(
|
| 613 |
+
self,
|
| 614 |
+
value: list[str],
|
| 615 |
+
) -> list[utils.NormalizedName]:
|
| 616 |
+
normalized_names = []
|
| 617 |
+
try:
|
| 618 |
+
for name in value:
|
| 619 |
+
normalized_names.append(utils.canonicalize_name(name, validate=True))
|
| 620 |
+
except utils.InvalidName as exc:
|
| 621 |
+
raise self._invalid_metadata(
|
| 622 |
+
f"{name!r} is invalid for {{field}}", cause=exc
|
| 623 |
+
) from exc
|
| 624 |
+
else:
|
| 625 |
+
return normalized_names
|
| 626 |
+
|
| 627 |
+
def _process_requires_python(self, value: str) -> specifiers.SpecifierSet:
|
| 628 |
+
try:
|
| 629 |
+
return specifiers.SpecifierSet(value)
|
| 630 |
+
except specifiers.InvalidSpecifier as exc:
|
| 631 |
+
raise self._invalid_metadata(
|
| 632 |
+
f"{value!r} is invalid for {{field}}", cause=exc
|
| 633 |
+
) from exc
|
| 634 |
+
|
| 635 |
+
def _process_requires_dist(
|
| 636 |
+
self,
|
| 637 |
+
value: list[str],
|
| 638 |
+
) -> list[requirements.Requirement]:
|
| 639 |
+
reqs = []
|
| 640 |
+
try:
|
| 641 |
+
for req in value:
|
| 642 |
+
reqs.append(requirements.Requirement(req))
|
| 643 |
+
except requirements.InvalidRequirement as exc:
|
| 644 |
+
raise self._invalid_metadata(
|
| 645 |
+
f"{req!r} is invalid for {{field}}", cause=exc
|
| 646 |
+
) from exc
|
| 647 |
+
else:
|
| 648 |
+
return reqs
|
| 649 |
+
|
| 650 |
+
def _process_license_expression(
|
| 651 |
+
self, value: str
|
| 652 |
+
) -> NormalizedLicenseExpression | None:
|
| 653 |
+
try:
|
| 654 |
+
return licenses.canonicalize_license_expression(value)
|
| 655 |
+
except ValueError as exc:
|
| 656 |
+
raise self._invalid_metadata(
|
| 657 |
+
f"{value!r} is invalid for {{field}}", cause=exc
|
| 658 |
+
) from exc
|
| 659 |
+
|
| 660 |
+
def _process_license_files(self, value: list[str]) -> list[str]:
|
| 661 |
+
paths = []
|
| 662 |
+
for path in value:
|
| 663 |
+
if ".." in path:
|
| 664 |
+
raise self._invalid_metadata(
|
| 665 |
+
f"{path!r} is invalid for {{field}}, "
|
| 666 |
+
"parent directory indicators are not allowed"
|
| 667 |
+
)
|
| 668 |
+
if "*" in path:
|
| 669 |
+
raise self._invalid_metadata(
|
| 670 |
+
f"{path!r} is invalid for {{field}}, paths must be resolved"
|
| 671 |
+
)
|
| 672 |
+
if (
|
| 673 |
+
pathlib.PurePosixPath(path).is_absolute()
|
| 674 |
+
or pathlib.PureWindowsPath(path).is_absolute()
|
| 675 |
+
):
|
| 676 |
+
raise self._invalid_metadata(
|
| 677 |
+
f"{path!r} is invalid for {{field}}, paths must be relative"
|
| 678 |
+
)
|
| 679 |
+
if pathlib.PureWindowsPath(path).as_posix() != path:
|
| 680 |
+
raise self._invalid_metadata(
|
| 681 |
+
f"{path!r} is invalid for {{field}}, "
|
| 682 |
+
"paths must use '/' delimiter"
|
| 683 |
+
)
|
| 684 |
+
paths.append(path)
|
| 685 |
+
return paths
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
class Metadata:
|
| 689 |
+
"""Representation of distribution metadata.
|
| 690 |
+
|
| 691 |
+
Compared to :class:`RawMetadata`, this class provides objects representing
|
| 692 |
+
metadata fields instead of only using built-in types. Any invalid metadata
|
| 693 |
+
will cause :exc:`InvalidMetadata` to be raised (with a
|
| 694 |
+
:py:attr:`~BaseException.__cause__` attribute as appropriate).
|
| 695 |
+
"""
|
| 696 |
+
|
| 697 |
+
_raw: RawMetadata
|
| 698 |
+
|
| 699 |
+
@classmethod
|
| 700 |
+
def from_raw(cls, data: RawMetadata, *, validate: bool = True) -> Metadata:
|
| 701 |
+
"""Create an instance from :class:`RawMetadata`.
|
| 702 |
+
|
| 703 |
+
If *validate* is true, all metadata will be validated. All exceptions
|
| 704 |
+
related to validation will be gathered and raised as an :class:`ExceptionGroup`.
|
| 705 |
+
"""
|
| 706 |
+
ins = cls()
|
| 707 |
+
ins._raw = data.copy() # Mutations occur due to caching enriched values.
|
| 708 |
+
|
| 709 |
+
if validate:
|
| 710 |
+
exceptions: list[Exception] = []
|
| 711 |
+
try:
|
| 712 |
+
metadata_version = ins.metadata_version
|
| 713 |
+
metadata_age = _VALID_METADATA_VERSIONS.index(metadata_version)
|
| 714 |
+
except InvalidMetadata as metadata_version_exc:
|
| 715 |
+
exceptions.append(metadata_version_exc)
|
| 716 |
+
metadata_version = None
|
| 717 |
+
|
| 718 |
+
# Make sure to check for the fields that are present, the required
|
| 719 |
+
# fields (so their absence can be reported).
|
| 720 |
+
fields_to_check = frozenset(ins._raw) | _REQUIRED_ATTRS
|
| 721 |
+
# Remove fields that have already been checked.
|
| 722 |
+
fields_to_check -= {"metadata_version"}
|
| 723 |
+
|
| 724 |
+
for key in fields_to_check:
|
| 725 |
+
try:
|
| 726 |
+
if metadata_version:
|
| 727 |
+
# Can't use getattr() as that triggers descriptor protocol which
|
| 728 |
+
# will fail due to no value for the instance argument.
|
| 729 |
+
try:
|
| 730 |
+
field_metadata_version = cls.__dict__[key].added
|
| 731 |
+
except KeyError:
|
| 732 |
+
exc = InvalidMetadata(key, f"unrecognized field: {key!r}")
|
| 733 |
+
exceptions.append(exc)
|
| 734 |
+
continue
|
| 735 |
+
field_age = _VALID_METADATA_VERSIONS.index(
|
| 736 |
+
field_metadata_version
|
| 737 |
+
)
|
| 738 |
+
if field_age > metadata_age:
|
| 739 |
+
field = _RAW_TO_EMAIL_MAPPING[key]
|
| 740 |
+
exc = InvalidMetadata(
|
| 741 |
+
field,
|
| 742 |
+
f"{field} introduced in metadata version "
|
| 743 |
+
f"{field_metadata_version}, not {metadata_version}",
|
| 744 |
+
)
|
| 745 |
+
exceptions.append(exc)
|
| 746 |
+
continue
|
| 747 |
+
getattr(ins, key)
|
| 748 |
+
except InvalidMetadata as exc:
|
| 749 |
+
exceptions.append(exc)
|
| 750 |
+
|
| 751 |
+
if exceptions:
|
| 752 |
+
raise ExceptionGroup("invalid metadata", exceptions)
|
| 753 |
+
|
| 754 |
+
return ins
|
| 755 |
+
|
| 756 |
+
@classmethod
|
| 757 |
+
def from_email(cls, data: bytes | str, *, validate: bool = True) -> Metadata:
|
| 758 |
+
"""Parse metadata from email headers.
|
| 759 |
+
|
| 760 |
+
If *validate* is true, the metadata will be validated. All exceptions
|
| 761 |
+
related to validation will be gathered and raised as an :class:`ExceptionGroup`.
|
| 762 |
+
"""
|
| 763 |
+
raw, unparsed = parse_email(data)
|
| 764 |
+
|
| 765 |
+
if validate:
|
| 766 |
+
exceptions: list[Exception] = []
|
| 767 |
+
for unparsed_key in unparsed:
|
| 768 |
+
if unparsed_key in _EMAIL_TO_RAW_MAPPING:
|
| 769 |
+
message = f"{unparsed_key!r} has invalid data"
|
| 770 |
+
else:
|
| 771 |
+
message = f"unrecognized field: {unparsed_key!r}"
|
| 772 |
+
exceptions.append(InvalidMetadata(unparsed_key, message))
|
| 773 |
+
|
| 774 |
+
if exceptions:
|
| 775 |
+
raise ExceptionGroup("unparsed", exceptions)
|
| 776 |
+
|
| 777 |
+
try:
|
| 778 |
+
return cls.from_raw(raw, validate=validate)
|
| 779 |
+
except ExceptionGroup as exc_group:
|
| 780 |
+
raise ExceptionGroup(
|
| 781 |
+
"invalid or unparsed metadata", exc_group.exceptions
|
| 782 |
+
) from None
|
| 783 |
+
|
| 784 |
+
metadata_version: _Validator[_MetadataVersion] = _Validator()
|
| 785 |
+
""":external:ref:`core-metadata-metadata-version`
|
| 786 |
+
(required; validated to be a valid metadata version)"""
|
| 787 |
+
# `name` is not normalized/typed to NormalizedName so as to provide access to
|
| 788 |
+
# the original/raw name.
|
| 789 |
+
name: _Validator[str] = _Validator()
|
| 790 |
+
""":external:ref:`core-metadata-name`
|
| 791 |
+
(required; validated using :func:`~packaging.utils.canonicalize_name` and its
|
| 792 |
+
*validate* parameter)"""
|
| 793 |
+
version: _Validator[version_module.Version] = _Validator()
|
| 794 |
+
""":external:ref:`core-metadata-version` (required)"""
|
| 795 |
+
dynamic: _Validator[list[str] | None] = _Validator(
|
| 796 |
+
added="2.2",
|
| 797 |
+
)
|
| 798 |
+
""":external:ref:`core-metadata-dynamic`
|
| 799 |
+
(validated against core metadata field names and lowercased)"""
|
| 800 |
+
platforms: _Validator[list[str] | None] = _Validator()
|
| 801 |
+
""":external:ref:`core-metadata-platform`"""
|
| 802 |
+
supported_platforms: _Validator[list[str] | None] = _Validator(added="1.1")
|
| 803 |
+
""":external:ref:`core-metadata-supported-platform`"""
|
| 804 |
+
summary: _Validator[str | None] = _Validator()
|
| 805 |
+
""":external:ref:`core-metadata-summary` (validated to contain no newlines)"""
|
| 806 |
+
description: _Validator[str | None] = _Validator() # TODO 2.1: can be in body
|
| 807 |
+
""":external:ref:`core-metadata-description`"""
|
| 808 |
+
description_content_type: _Validator[str | None] = _Validator(added="2.1")
|
| 809 |
+
""":external:ref:`core-metadata-description-content-type` (validated)"""
|
| 810 |
+
keywords: _Validator[list[str] | None] = _Validator()
|
| 811 |
+
""":external:ref:`core-metadata-keywords`"""
|
| 812 |
+
home_page: _Validator[str | None] = _Validator()
|
| 813 |
+
""":external:ref:`core-metadata-home-page`"""
|
| 814 |
+
download_url: _Validator[str | None] = _Validator(added="1.1")
|
| 815 |
+
""":external:ref:`core-metadata-download-url`"""
|
| 816 |
+
author: _Validator[str | None] = _Validator()
|
| 817 |
+
""":external:ref:`core-metadata-author`"""
|
| 818 |
+
author_email: _Validator[str | None] = _Validator()
|
| 819 |
+
""":external:ref:`core-metadata-author-email`"""
|
| 820 |
+
maintainer: _Validator[str | None] = _Validator(added="1.2")
|
| 821 |
+
""":external:ref:`core-metadata-maintainer`"""
|
| 822 |
+
maintainer_email: _Validator[str | None] = _Validator(added="1.2")
|
| 823 |
+
""":external:ref:`core-metadata-maintainer-email`"""
|
| 824 |
+
license: _Validator[str | None] = _Validator()
|
| 825 |
+
""":external:ref:`core-metadata-license`"""
|
| 826 |
+
license_expression: _Validator[NormalizedLicenseExpression | None] = _Validator(
|
| 827 |
+
added="2.4"
|
| 828 |
+
)
|
| 829 |
+
""":external:ref:`core-metadata-license-expression`"""
|
| 830 |
+
license_files: _Validator[list[str] | None] = _Validator(added="2.4")
|
| 831 |
+
""":external:ref:`core-metadata-license-file`"""
|
| 832 |
+
classifiers: _Validator[list[str] | None] = _Validator(added="1.1")
|
| 833 |
+
""":external:ref:`core-metadata-classifier`"""
|
| 834 |
+
requires_dist: _Validator[list[requirements.Requirement] | None] = _Validator(
|
| 835 |
+
added="1.2"
|
| 836 |
+
)
|
| 837 |
+
""":external:ref:`core-metadata-requires-dist`"""
|
| 838 |
+
requires_python: _Validator[specifiers.SpecifierSet | None] = _Validator(
|
| 839 |
+
added="1.2"
|
| 840 |
+
)
|
| 841 |
+
""":external:ref:`core-metadata-requires-python`"""
|
| 842 |
+
# Because `Requires-External` allows for non-PEP 440 version specifiers, we
|
| 843 |
+
# don't do any processing on the values.
|
| 844 |
+
requires_external: _Validator[list[str] | None] = _Validator(added="1.2")
|
| 845 |
+
""":external:ref:`core-metadata-requires-external`"""
|
| 846 |
+
project_urls: _Validator[dict[str, str] | None] = _Validator(added="1.2")
|
| 847 |
+
""":external:ref:`core-metadata-project-url`"""
|
| 848 |
+
# PEP 685 lets us raise an error if an extra doesn't pass `Name` validation
|
| 849 |
+
# regardless of metadata version.
|
| 850 |
+
provides_extra: _Validator[list[utils.NormalizedName] | None] = _Validator(
|
| 851 |
+
added="2.1",
|
| 852 |
+
)
|
| 853 |
+
""":external:ref:`core-metadata-provides-extra`"""
|
| 854 |
+
provides_dist: _Validator[list[str] | None] = _Validator(added="1.2")
|
| 855 |
+
""":external:ref:`core-metadata-provides-dist`"""
|
| 856 |
+
obsoletes_dist: _Validator[list[str] | None] = _Validator(added="1.2")
|
| 857 |
+
""":external:ref:`core-metadata-obsoletes-dist`"""
|
| 858 |
+
requires: _Validator[list[str] | None] = _Validator(added="1.1")
|
| 859 |
+
"""``Requires`` (deprecated)"""
|
| 860 |
+
provides: _Validator[list[str] | None] = _Validator(added="1.1")
|
| 861 |
+
"""``Provides`` (deprecated)"""
|
| 862 |
+
obsoletes: _Validator[list[str] | None] = _Validator(added="1.1")
|
| 863 |
+
"""``Obsoletes`` (deprecated)"""
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/py.typed
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/specifiers.py
ADDED
|
@@ -0,0 +1,1020 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is dual licensed under the terms of the Apache License, Version
|
| 2 |
+
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
| 3 |
+
# for complete details.
|
| 4 |
+
"""
|
| 5 |
+
.. testsetup::
|
| 6 |
+
|
| 7 |
+
from packaging.specifiers import Specifier, SpecifierSet, InvalidSpecifier
|
| 8 |
+
from packaging.version import Version
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import abc
|
| 14 |
+
import itertools
|
| 15 |
+
import re
|
| 16 |
+
from typing import Callable, Iterable, Iterator, TypeVar, Union
|
| 17 |
+
|
| 18 |
+
from .utils import canonicalize_version
|
| 19 |
+
from .version import Version
|
| 20 |
+
|
| 21 |
+
UnparsedVersion = Union[Version, str]
|
| 22 |
+
UnparsedVersionVar = TypeVar("UnparsedVersionVar", bound=UnparsedVersion)
|
| 23 |
+
CallableOperator = Callable[[Version, str], bool]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _coerce_version(version: UnparsedVersion) -> Version:
|
| 27 |
+
if not isinstance(version, Version):
|
| 28 |
+
version = Version(version)
|
| 29 |
+
return version
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class InvalidSpecifier(ValueError):
|
| 33 |
+
"""
|
| 34 |
+
Raised when attempting to create a :class:`Specifier` with a specifier
|
| 35 |
+
string that is invalid.
|
| 36 |
+
|
| 37 |
+
>>> Specifier("lolwat")
|
| 38 |
+
Traceback (most recent call last):
|
| 39 |
+
...
|
| 40 |
+
packaging.specifiers.InvalidSpecifier: Invalid specifier: 'lolwat'
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class BaseSpecifier(metaclass=abc.ABCMeta):
|
| 45 |
+
@abc.abstractmethod
|
| 46 |
+
def __str__(self) -> str:
|
| 47 |
+
"""
|
| 48 |
+
Returns the str representation of this Specifier-like object. This
|
| 49 |
+
should be representative of the Specifier itself.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
@abc.abstractmethod
|
| 53 |
+
def __hash__(self) -> int:
|
| 54 |
+
"""
|
| 55 |
+
Returns a hash value for this Specifier-like object.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
@abc.abstractmethod
|
| 59 |
+
def __eq__(self, other: object) -> bool:
|
| 60 |
+
"""
|
| 61 |
+
Returns a boolean representing whether or not the two Specifier-like
|
| 62 |
+
objects are equal.
|
| 63 |
+
|
| 64 |
+
:param other: The other object to check against.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
@abc.abstractmethod
|
| 69 |
+
def prereleases(self) -> bool | None:
|
| 70 |
+
"""Whether or not pre-releases as a whole are allowed.
|
| 71 |
+
|
| 72 |
+
This can be set to either ``True`` or ``False`` to explicitly enable or disable
|
| 73 |
+
prereleases or it can be set to ``None`` (the default) to use default semantics.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
@prereleases.setter
|
| 77 |
+
def prereleases(self, value: bool) -> None:
|
| 78 |
+
"""Setter for :attr:`prereleases`.
|
| 79 |
+
|
| 80 |
+
:param value: The value to set.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
@abc.abstractmethod
|
| 84 |
+
def contains(self, item: str, prereleases: bool | None = None) -> bool:
|
| 85 |
+
"""
|
| 86 |
+
Determines if the given item is contained within this specifier.
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
@abc.abstractmethod
|
| 90 |
+
def filter(
|
| 91 |
+
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
|
| 92 |
+
) -> Iterator[UnparsedVersionVar]:
|
| 93 |
+
"""
|
| 94 |
+
Takes an iterable of items and filters them so that only items which
|
| 95 |
+
are contained within this specifier are allowed in it.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Specifier(BaseSpecifier):
|
| 100 |
+
"""This class abstracts handling of version specifiers.
|
| 101 |
+
|
| 102 |
+
.. tip::
|
| 103 |
+
|
| 104 |
+
It is generally not required to instantiate this manually. You should instead
|
| 105 |
+
prefer to work with :class:`SpecifierSet` instead, which can parse
|
| 106 |
+
comma-separated version specifiers (which is what package metadata contains).
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
_operator_regex_str = r"""
|
| 110 |
+
(?P<operator>(~=|==|!=|<=|>=|<|>|===))
|
| 111 |
+
"""
|
| 112 |
+
_version_regex_str = r"""
|
| 113 |
+
(?P<version>
|
| 114 |
+
(?:
|
| 115 |
+
# The identity operators allow for an escape hatch that will
|
| 116 |
+
# do an exact string match of the version you wish to install.
|
| 117 |
+
# This will not be parsed by PEP 440 and we cannot determine
|
| 118 |
+
# any semantic meaning from it. This operator is discouraged
|
| 119 |
+
# but included entirely as an escape hatch.
|
| 120 |
+
(?<====) # Only match for the identity operator
|
| 121 |
+
\s*
|
| 122 |
+
[^\s;)]* # The arbitrary version can be just about anything,
|
| 123 |
+
# we match everything except for whitespace, a
|
| 124 |
+
# semi-colon for marker support, and a closing paren
|
| 125 |
+
# since versions can be enclosed in them.
|
| 126 |
+
)
|
| 127 |
+
|
|
| 128 |
+
(?:
|
| 129 |
+
# The (non)equality operators allow for wild card and local
|
| 130 |
+
# versions to be specified so we have to define these two
|
| 131 |
+
# operators separately to enable that.
|
| 132 |
+
(?<===|!=) # Only match for equals and not equals
|
| 133 |
+
|
| 134 |
+
\s*
|
| 135 |
+
v?
|
| 136 |
+
(?:[0-9]+!)? # epoch
|
| 137 |
+
[0-9]+(?:\.[0-9]+)* # release
|
| 138 |
+
|
| 139 |
+
# You cannot use a wild card and a pre-release, post-release, a dev or
|
| 140 |
+
# local version together so group them with a | and make them optional.
|
| 141 |
+
(?:
|
| 142 |
+
\.\* # Wild card syntax of .*
|
| 143 |
+
|
|
| 144 |
+
(?: # pre release
|
| 145 |
+
[-_\.]?
|
| 146 |
+
(alpha|beta|preview|pre|a|b|c|rc)
|
| 147 |
+
[-_\.]?
|
| 148 |
+
[0-9]*
|
| 149 |
+
)?
|
| 150 |
+
(?: # post release
|
| 151 |
+
(?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
|
| 152 |
+
)?
|
| 153 |
+
(?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
|
| 154 |
+
(?:\+[a-z0-9]+(?:[-_\.][a-z0-9]+)*)? # local
|
| 155 |
+
)?
|
| 156 |
+
)
|
| 157 |
+
|
|
| 158 |
+
(?:
|
| 159 |
+
# The compatible operator requires at least two digits in the
|
| 160 |
+
# release segment.
|
| 161 |
+
(?<=~=) # Only match for the compatible operator
|
| 162 |
+
|
| 163 |
+
\s*
|
| 164 |
+
v?
|
| 165 |
+
(?:[0-9]+!)? # epoch
|
| 166 |
+
[0-9]+(?:\.[0-9]+)+ # release (We have a + instead of a *)
|
| 167 |
+
(?: # pre release
|
| 168 |
+
[-_\.]?
|
| 169 |
+
(alpha|beta|preview|pre|a|b|c|rc)
|
| 170 |
+
[-_\.]?
|
| 171 |
+
[0-9]*
|
| 172 |
+
)?
|
| 173 |
+
(?: # post release
|
| 174 |
+
(?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
|
| 175 |
+
)?
|
| 176 |
+
(?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
|
| 177 |
+
)
|
| 178 |
+
|
|
| 179 |
+
(?:
|
| 180 |
+
# All other operators only allow a sub set of what the
|
| 181 |
+
# (non)equality operators do. Specifically they do not allow
|
| 182 |
+
# local versions to be specified nor do they allow the prefix
|
| 183 |
+
# matching wild cards.
|
| 184 |
+
(?<!==|!=|~=) # We have special cases for these
|
| 185 |
+
# operators so we want to make sure they
|
| 186 |
+
# don't match here.
|
| 187 |
+
|
| 188 |
+
\s*
|
| 189 |
+
v?
|
| 190 |
+
(?:[0-9]+!)? # epoch
|
| 191 |
+
[0-9]+(?:\.[0-9]+)* # release
|
| 192 |
+
(?: # pre release
|
| 193 |
+
[-_\.]?
|
| 194 |
+
(alpha|beta|preview|pre|a|b|c|rc)
|
| 195 |
+
[-_\.]?
|
| 196 |
+
[0-9]*
|
| 197 |
+
)?
|
| 198 |
+
(?: # post release
|
| 199 |
+
(?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
|
| 200 |
+
)?
|
| 201 |
+
(?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
|
| 202 |
+
)
|
| 203 |
+
)
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
_regex = re.compile(
|
| 207 |
+
r"^\s*" + _operator_regex_str + _version_regex_str + r"\s*$",
|
| 208 |
+
re.VERBOSE | re.IGNORECASE,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
_operators = {
|
| 212 |
+
"~=": "compatible",
|
| 213 |
+
"==": "equal",
|
| 214 |
+
"!=": "not_equal",
|
| 215 |
+
"<=": "less_than_equal",
|
| 216 |
+
">=": "greater_than_equal",
|
| 217 |
+
"<": "less_than",
|
| 218 |
+
">": "greater_than",
|
| 219 |
+
"===": "arbitrary",
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
def __init__(self, spec: str = "", prereleases: bool | None = None) -> None:
|
| 223 |
+
"""Initialize a Specifier instance.
|
| 224 |
+
|
| 225 |
+
:param spec:
|
| 226 |
+
The string representation of a specifier which will be parsed and
|
| 227 |
+
normalized before use.
|
| 228 |
+
:param prereleases:
|
| 229 |
+
This tells the specifier if it should accept prerelease versions if
|
| 230 |
+
applicable or not. The default of ``None`` will autodetect it from the
|
| 231 |
+
given specifiers.
|
| 232 |
+
:raises InvalidSpecifier:
|
| 233 |
+
If the given specifier is invalid (i.e. bad syntax).
|
| 234 |
+
"""
|
| 235 |
+
match = self._regex.search(spec)
|
| 236 |
+
if not match:
|
| 237 |
+
raise InvalidSpecifier(f"Invalid specifier: {spec!r}")
|
| 238 |
+
|
| 239 |
+
self._spec: tuple[str, str] = (
|
| 240 |
+
match.group("operator").strip(),
|
| 241 |
+
match.group("version").strip(),
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Store whether or not this Specifier should accept prereleases
|
| 245 |
+
self._prereleases = prereleases
|
| 246 |
+
|
| 247 |
+
# https://github.com/python/mypy/pull/13475#pullrequestreview-1079784515
|
| 248 |
+
@property # type: ignore[override]
|
| 249 |
+
def prereleases(self) -> bool:
|
| 250 |
+
# If there is an explicit prereleases set for this, then we'll just
|
| 251 |
+
# blindly use that.
|
| 252 |
+
if self._prereleases is not None:
|
| 253 |
+
return self._prereleases
|
| 254 |
+
|
| 255 |
+
# Look at all of our specifiers and determine if they are inclusive
|
| 256 |
+
# operators, and if they are if they are including an explicit
|
| 257 |
+
# prerelease.
|
| 258 |
+
operator, version = self._spec
|
| 259 |
+
if operator in ["==", ">=", "<=", "~=", "===", ">", "<"]:
|
| 260 |
+
# The == specifier can include a trailing .*, if it does we
|
| 261 |
+
# want to remove before parsing.
|
| 262 |
+
if operator == "==" and version.endswith(".*"):
|
| 263 |
+
version = version[:-2]
|
| 264 |
+
|
| 265 |
+
# Parse the version, and if it is a pre-release than this
|
| 266 |
+
# specifier allows pre-releases.
|
| 267 |
+
if Version(version).is_prerelease:
|
| 268 |
+
return True
|
| 269 |
+
|
| 270 |
+
return False
|
| 271 |
+
|
| 272 |
+
@prereleases.setter
|
| 273 |
+
def prereleases(self, value: bool) -> None:
|
| 274 |
+
self._prereleases = value
|
| 275 |
+
|
| 276 |
+
@property
|
| 277 |
+
def operator(self) -> str:
|
| 278 |
+
"""The operator of this specifier.
|
| 279 |
+
|
| 280 |
+
>>> Specifier("==1.2.3").operator
|
| 281 |
+
'=='
|
| 282 |
+
"""
|
| 283 |
+
return self._spec[0]
|
| 284 |
+
|
| 285 |
+
@property
|
| 286 |
+
def version(self) -> str:
|
| 287 |
+
"""The version of this specifier.
|
| 288 |
+
|
| 289 |
+
>>> Specifier("==1.2.3").version
|
| 290 |
+
'1.2.3'
|
| 291 |
+
"""
|
| 292 |
+
return self._spec[1]
|
| 293 |
+
|
| 294 |
+
def __repr__(self) -> str:
|
| 295 |
+
"""A representation of the Specifier that shows all internal state.
|
| 296 |
+
|
| 297 |
+
>>> Specifier('>=1.0.0')
|
| 298 |
+
<Specifier('>=1.0.0')>
|
| 299 |
+
>>> Specifier('>=1.0.0', prereleases=False)
|
| 300 |
+
<Specifier('>=1.0.0', prereleases=False)>
|
| 301 |
+
>>> Specifier('>=1.0.0', prereleases=True)
|
| 302 |
+
<Specifier('>=1.0.0', prereleases=True)>
|
| 303 |
+
"""
|
| 304 |
+
pre = (
|
| 305 |
+
f", prereleases={self.prereleases!r}"
|
| 306 |
+
if self._prereleases is not None
|
| 307 |
+
else ""
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
return f"<{self.__class__.__name__}({str(self)!r}{pre})>"
|
| 311 |
+
|
| 312 |
+
def __str__(self) -> str:
|
| 313 |
+
"""A string representation of the Specifier that can be round-tripped.
|
| 314 |
+
|
| 315 |
+
>>> str(Specifier('>=1.0.0'))
|
| 316 |
+
'>=1.0.0'
|
| 317 |
+
>>> str(Specifier('>=1.0.0', prereleases=False))
|
| 318 |
+
'>=1.0.0'
|
| 319 |
+
"""
|
| 320 |
+
return "{}{}".format(*self._spec)
|
| 321 |
+
|
| 322 |
+
@property
|
| 323 |
+
def _canonical_spec(self) -> tuple[str, str]:
|
| 324 |
+
canonical_version = canonicalize_version(
|
| 325 |
+
self._spec[1],
|
| 326 |
+
strip_trailing_zero=(self._spec[0] != "~="),
|
| 327 |
+
)
|
| 328 |
+
return self._spec[0], canonical_version
|
| 329 |
+
|
| 330 |
+
def __hash__(self) -> int:
|
| 331 |
+
return hash(self._canonical_spec)
|
| 332 |
+
|
| 333 |
+
def __eq__(self, other: object) -> bool:
|
| 334 |
+
"""Whether or not the two Specifier-like objects are equal.
|
| 335 |
+
|
| 336 |
+
:param other: The other object to check against.
|
| 337 |
+
|
| 338 |
+
The value of :attr:`prereleases` is ignored.
|
| 339 |
+
|
| 340 |
+
>>> Specifier("==1.2.3") == Specifier("== 1.2.3.0")
|
| 341 |
+
True
|
| 342 |
+
>>> (Specifier("==1.2.3", prereleases=False) ==
|
| 343 |
+
... Specifier("==1.2.3", prereleases=True))
|
| 344 |
+
True
|
| 345 |
+
>>> Specifier("==1.2.3") == "==1.2.3"
|
| 346 |
+
True
|
| 347 |
+
>>> Specifier("==1.2.3") == Specifier("==1.2.4")
|
| 348 |
+
False
|
| 349 |
+
>>> Specifier("==1.2.3") == Specifier("~=1.2.3")
|
| 350 |
+
False
|
| 351 |
+
"""
|
| 352 |
+
if isinstance(other, str):
|
| 353 |
+
try:
|
| 354 |
+
other = self.__class__(str(other))
|
| 355 |
+
except InvalidSpecifier:
|
| 356 |
+
return NotImplemented
|
| 357 |
+
elif not isinstance(other, self.__class__):
|
| 358 |
+
return NotImplemented
|
| 359 |
+
|
| 360 |
+
return self._canonical_spec == other._canonical_spec
|
| 361 |
+
|
| 362 |
+
def _get_operator(self, op: str) -> CallableOperator:
|
| 363 |
+
operator_callable: CallableOperator = getattr(
|
| 364 |
+
self, f"_compare_{self._operators[op]}"
|
| 365 |
+
)
|
| 366 |
+
return operator_callable
|
| 367 |
+
|
| 368 |
+
def _compare_compatible(self, prospective: Version, spec: str) -> bool:
|
| 369 |
+
# Compatible releases have an equivalent combination of >= and ==. That
|
| 370 |
+
# is that ~=2.2 is equivalent to >=2.2,==2.*. This allows us to
|
| 371 |
+
# implement this in terms of the other specifiers instead of
|
| 372 |
+
# implementing it ourselves. The only thing we need to do is construct
|
| 373 |
+
# the other specifiers.
|
| 374 |
+
|
| 375 |
+
# We want everything but the last item in the version, but we want to
|
| 376 |
+
# ignore suffix segments.
|
| 377 |
+
prefix = _version_join(
|
| 378 |
+
list(itertools.takewhile(_is_not_suffix, _version_split(spec)))[:-1]
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# Add the prefix notation to the end of our string
|
| 382 |
+
prefix += ".*"
|
| 383 |
+
|
| 384 |
+
return self._get_operator(">=")(prospective, spec) and self._get_operator("==")(
|
| 385 |
+
prospective, prefix
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
def _compare_equal(self, prospective: Version, spec: str) -> bool:
|
| 389 |
+
# We need special logic to handle prefix matching
|
| 390 |
+
if spec.endswith(".*"):
|
| 391 |
+
# In the case of prefix matching we want to ignore local segment.
|
| 392 |
+
normalized_prospective = canonicalize_version(
|
| 393 |
+
prospective.public, strip_trailing_zero=False
|
| 394 |
+
)
|
| 395 |
+
# Get the normalized version string ignoring the trailing .*
|
| 396 |
+
normalized_spec = canonicalize_version(spec[:-2], strip_trailing_zero=False)
|
| 397 |
+
# Split the spec out by bangs and dots, and pretend that there is
|
| 398 |
+
# an implicit dot in between a release segment and a pre-release segment.
|
| 399 |
+
split_spec = _version_split(normalized_spec)
|
| 400 |
+
|
| 401 |
+
# Split the prospective version out by bangs and dots, and pretend
|
| 402 |
+
# that there is an implicit dot in between a release segment and
|
| 403 |
+
# a pre-release segment.
|
| 404 |
+
split_prospective = _version_split(normalized_prospective)
|
| 405 |
+
|
| 406 |
+
# 0-pad the prospective version before shortening it to get the correct
|
| 407 |
+
# shortened version.
|
| 408 |
+
padded_prospective, _ = _pad_version(split_prospective, split_spec)
|
| 409 |
+
|
| 410 |
+
# Shorten the prospective version to be the same length as the spec
|
| 411 |
+
# so that we can determine if the specifier is a prefix of the
|
| 412 |
+
# prospective version or not.
|
| 413 |
+
shortened_prospective = padded_prospective[: len(split_spec)]
|
| 414 |
+
|
| 415 |
+
return shortened_prospective == split_spec
|
| 416 |
+
else:
|
| 417 |
+
# Convert our spec string into a Version
|
| 418 |
+
spec_version = Version(spec)
|
| 419 |
+
|
| 420 |
+
# If the specifier does not have a local segment, then we want to
|
| 421 |
+
# act as if the prospective version also does not have a local
|
| 422 |
+
# segment.
|
| 423 |
+
if not spec_version.local:
|
| 424 |
+
prospective = Version(prospective.public)
|
| 425 |
+
|
| 426 |
+
return prospective == spec_version
|
| 427 |
+
|
| 428 |
+
def _compare_not_equal(self, prospective: Version, spec: str) -> bool:
|
| 429 |
+
return not self._compare_equal(prospective, spec)
|
| 430 |
+
|
| 431 |
+
def _compare_less_than_equal(self, prospective: Version, spec: str) -> bool:
|
| 432 |
+
# NB: Local version identifiers are NOT permitted in the version
|
| 433 |
+
# specifier, so local version labels can be universally removed from
|
| 434 |
+
# the prospective version.
|
| 435 |
+
return Version(prospective.public) <= Version(spec)
|
| 436 |
+
|
| 437 |
+
def _compare_greater_than_equal(self, prospective: Version, spec: str) -> bool:
|
| 438 |
+
# NB: Local version identifiers are NOT permitted in the version
|
| 439 |
+
# specifier, so local version labels can be universally removed from
|
| 440 |
+
# the prospective version.
|
| 441 |
+
return Version(prospective.public) >= Version(spec)
|
| 442 |
+
|
| 443 |
+
def _compare_less_than(self, prospective: Version, spec_str: str) -> bool:
|
| 444 |
+
# Convert our spec to a Version instance, since we'll want to work with
|
| 445 |
+
# it as a version.
|
| 446 |
+
spec = Version(spec_str)
|
| 447 |
+
|
| 448 |
+
# Check to see if the prospective version is less than the spec
|
| 449 |
+
# version. If it's not we can short circuit and just return False now
|
| 450 |
+
# instead of doing extra unneeded work.
|
| 451 |
+
if not prospective < spec:
|
| 452 |
+
return False
|
| 453 |
+
|
| 454 |
+
# This special case is here so that, unless the specifier itself
|
| 455 |
+
# includes is a pre-release version, that we do not accept pre-release
|
| 456 |
+
# versions for the version mentioned in the specifier (e.g. <3.1 should
|
| 457 |
+
# not match 3.1.dev0, but should match 3.0.dev0).
|
| 458 |
+
if not spec.is_prerelease and prospective.is_prerelease:
|
| 459 |
+
if Version(prospective.base_version) == Version(spec.base_version):
|
| 460 |
+
return False
|
| 461 |
+
|
| 462 |
+
# If we've gotten to here, it means that prospective version is both
|
| 463 |
+
# less than the spec version *and* it's not a pre-release of the same
|
| 464 |
+
# version in the spec.
|
| 465 |
+
return True
|
| 466 |
+
|
| 467 |
+
def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
|
| 468 |
+
# Convert our spec to a Version instance, since we'll want to work with
|
| 469 |
+
# it as a version.
|
| 470 |
+
spec = Version(spec_str)
|
| 471 |
+
|
| 472 |
+
# Check to see if the prospective version is greater than the spec
|
| 473 |
+
# version. If it's not we can short circuit and just return False now
|
| 474 |
+
# instead of doing extra unneeded work.
|
| 475 |
+
if not prospective > spec:
|
| 476 |
+
return False
|
| 477 |
+
|
| 478 |
+
# This special case is here so that, unless the specifier itself
|
| 479 |
+
# includes is a post-release version, that we do not accept
|
| 480 |
+
# post-release versions for the version mentioned in the specifier
|
| 481 |
+
# (e.g. >3.1 should not match 3.0.post0, but should match 3.2.post0).
|
| 482 |
+
if not spec.is_postrelease and prospective.is_postrelease:
|
| 483 |
+
if Version(prospective.base_version) == Version(spec.base_version):
|
| 484 |
+
return False
|
| 485 |
+
|
| 486 |
+
# Ensure that we do not allow a local version of the version mentioned
|
| 487 |
+
# in the specifier, which is technically greater than, to match.
|
| 488 |
+
if prospective.local is not None:
|
| 489 |
+
if Version(prospective.base_version) == Version(spec.base_version):
|
| 490 |
+
return False
|
| 491 |
+
|
| 492 |
+
# If we've gotten to here, it means that prospective version is both
|
| 493 |
+
# greater than the spec version *and* it's not a pre-release of the
|
| 494 |
+
# same version in the spec.
|
| 495 |
+
return True
|
| 496 |
+
|
| 497 |
+
def _compare_arbitrary(self, prospective: Version, spec: str) -> bool:
|
| 498 |
+
return str(prospective).lower() == str(spec).lower()
|
| 499 |
+
|
| 500 |
+
def __contains__(self, item: str | Version) -> bool:
|
| 501 |
+
"""Return whether or not the item is contained in this specifier.
|
| 502 |
+
|
| 503 |
+
:param item: The item to check for.
|
| 504 |
+
|
| 505 |
+
This is used for the ``in`` operator and behaves the same as
|
| 506 |
+
:meth:`contains` with no ``prereleases`` argument passed.
|
| 507 |
+
|
| 508 |
+
>>> "1.2.3" in Specifier(">=1.2.3")
|
| 509 |
+
True
|
| 510 |
+
>>> Version("1.2.3") in Specifier(">=1.2.3")
|
| 511 |
+
True
|
| 512 |
+
>>> "1.0.0" in Specifier(">=1.2.3")
|
| 513 |
+
False
|
| 514 |
+
>>> "1.3.0a1" in Specifier(">=1.2.3")
|
| 515 |
+
False
|
| 516 |
+
>>> "1.3.0a1" in Specifier(">=1.2.3", prereleases=True)
|
| 517 |
+
True
|
| 518 |
+
"""
|
| 519 |
+
return self.contains(item)
|
| 520 |
+
|
| 521 |
+
def contains(self, item: UnparsedVersion, prereleases: bool | None = None) -> bool:
|
| 522 |
+
"""Return whether or not the item is contained in this specifier.
|
| 523 |
+
|
| 524 |
+
:param item:
|
| 525 |
+
The item to check for, which can be a version string or a
|
| 526 |
+
:class:`Version` instance.
|
| 527 |
+
:param prereleases:
|
| 528 |
+
Whether or not to match prereleases with this Specifier. If set to
|
| 529 |
+
``None`` (the default), it uses :attr:`prereleases` to determine
|
| 530 |
+
whether or not prereleases are allowed.
|
| 531 |
+
|
| 532 |
+
>>> Specifier(">=1.2.3").contains("1.2.3")
|
| 533 |
+
True
|
| 534 |
+
>>> Specifier(">=1.2.3").contains(Version("1.2.3"))
|
| 535 |
+
True
|
| 536 |
+
>>> Specifier(">=1.2.3").contains("1.0.0")
|
| 537 |
+
False
|
| 538 |
+
>>> Specifier(">=1.2.3").contains("1.3.0a1")
|
| 539 |
+
False
|
| 540 |
+
>>> Specifier(">=1.2.3", prereleases=True).contains("1.3.0a1")
|
| 541 |
+
True
|
| 542 |
+
>>> Specifier(">=1.2.3").contains("1.3.0a1", prereleases=True)
|
| 543 |
+
True
|
| 544 |
+
"""
|
| 545 |
+
|
| 546 |
+
# Determine if prereleases are to be allowed or not.
|
| 547 |
+
if prereleases is None:
|
| 548 |
+
prereleases = self.prereleases
|
| 549 |
+
|
| 550 |
+
# Normalize item to a Version, this allows us to have a shortcut for
|
| 551 |
+
# "2.0" in Specifier(">=2")
|
| 552 |
+
normalized_item = _coerce_version(item)
|
| 553 |
+
|
| 554 |
+
# Determine if we should be supporting prereleases in this specifier
|
| 555 |
+
# or not, if we do not support prereleases than we can short circuit
|
| 556 |
+
# logic if this version is a prereleases.
|
| 557 |
+
if normalized_item.is_prerelease and not prereleases:
|
| 558 |
+
return False
|
| 559 |
+
|
| 560 |
+
# Actually do the comparison to determine if this item is contained
|
| 561 |
+
# within this Specifier or not.
|
| 562 |
+
operator_callable: CallableOperator = self._get_operator(self.operator)
|
| 563 |
+
return operator_callable(normalized_item, self.version)
|
| 564 |
+
|
| 565 |
+
def filter(
|
| 566 |
+
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
|
| 567 |
+
) -> Iterator[UnparsedVersionVar]:
|
| 568 |
+
"""Filter items in the given iterable, that match the specifier.
|
| 569 |
+
|
| 570 |
+
:param iterable:
|
| 571 |
+
An iterable that can contain version strings and :class:`Version` instances.
|
| 572 |
+
The items in the iterable will be filtered according to the specifier.
|
| 573 |
+
:param prereleases:
|
| 574 |
+
Whether or not to allow prereleases in the returned iterator. If set to
|
| 575 |
+
``None`` (the default), it will be intelligently decide whether to allow
|
| 576 |
+
prereleases or not (based on the :attr:`prereleases` attribute, and
|
| 577 |
+
whether the only versions matching are prereleases).
|
| 578 |
+
|
| 579 |
+
This method is smarter than just ``filter(Specifier().contains, [...])``
|
| 580 |
+
because it implements the rule from :pep:`440` that a prerelease item
|
| 581 |
+
SHOULD be accepted if no other versions match the given specifier.
|
| 582 |
+
|
| 583 |
+
>>> list(Specifier(">=1.2.3").filter(["1.2", "1.3", "1.5a1"]))
|
| 584 |
+
['1.3']
|
| 585 |
+
>>> list(Specifier(">=1.2.3").filter(["1.2", "1.2.3", "1.3", Version("1.4")]))
|
| 586 |
+
['1.2.3', '1.3', <Version('1.4')>]
|
| 587 |
+
>>> list(Specifier(">=1.2.3").filter(["1.2", "1.5a1"]))
|
| 588 |
+
['1.5a1']
|
| 589 |
+
>>> list(Specifier(">=1.2.3").filter(["1.3", "1.5a1"], prereleases=True))
|
| 590 |
+
['1.3', '1.5a1']
|
| 591 |
+
>>> list(Specifier(">=1.2.3", prereleases=True).filter(["1.3", "1.5a1"]))
|
| 592 |
+
['1.3', '1.5a1']
|
| 593 |
+
"""
|
| 594 |
+
|
| 595 |
+
yielded = False
|
| 596 |
+
found_prereleases = []
|
| 597 |
+
|
| 598 |
+
kw = {"prereleases": prereleases if prereleases is not None else True}
|
| 599 |
+
|
| 600 |
+
# Attempt to iterate over all the values in the iterable and if any of
|
| 601 |
+
# them match, yield them.
|
| 602 |
+
for version in iterable:
|
| 603 |
+
parsed_version = _coerce_version(version)
|
| 604 |
+
|
| 605 |
+
if self.contains(parsed_version, **kw):
|
| 606 |
+
# If our version is a prerelease, and we were not set to allow
|
| 607 |
+
# prereleases, then we'll store it for later in case nothing
|
| 608 |
+
# else matches this specifier.
|
| 609 |
+
if parsed_version.is_prerelease and not (
|
| 610 |
+
prereleases or self.prereleases
|
| 611 |
+
):
|
| 612 |
+
found_prereleases.append(version)
|
| 613 |
+
# Either this is not a prerelease, or we should have been
|
| 614 |
+
# accepting prereleases from the beginning.
|
| 615 |
+
else:
|
| 616 |
+
yielded = True
|
| 617 |
+
yield version
|
| 618 |
+
|
| 619 |
+
# Now that we've iterated over everything, determine if we've yielded
|
| 620 |
+
# any values, and if we have not and we have any prereleases stored up
|
| 621 |
+
# then we will go ahead and yield the prereleases.
|
| 622 |
+
if not yielded and found_prereleases:
|
| 623 |
+
for version in found_prereleases:
|
| 624 |
+
yield version
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
_prefix_regex = re.compile(r"^([0-9]+)((?:a|b|c|rc)[0-9]+)$")
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def _version_split(version: str) -> list[str]:
|
| 631 |
+
"""Split version into components.
|
| 632 |
+
|
| 633 |
+
The split components are intended for version comparison. The logic does
|
| 634 |
+
not attempt to retain the original version string, so joining the
|
| 635 |
+
components back with :func:`_version_join` may not produce the original
|
| 636 |
+
version string.
|
| 637 |
+
"""
|
| 638 |
+
result: list[str] = []
|
| 639 |
+
|
| 640 |
+
epoch, _, rest = version.rpartition("!")
|
| 641 |
+
result.append(epoch or "0")
|
| 642 |
+
|
| 643 |
+
for item in rest.split("."):
|
| 644 |
+
match = _prefix_regex.search(item)
|
| 645 |
+
if match:
|
| 646 |
+
result.extend(match.groups())
|
| 647 |
+
else:
|
| 648 |
+
result.append(item)
|
| 649 |
+
return result
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def _version_join(components: list[str]) -> str:
|
| 653 |
+
"""Join split version components into a version string.
|
| 654 |
+
|
| 655 |
+
This function assumes the input came from :func:`_version_split`, where the
|
| 656 |
+
first component must be the epoch (either empty or numeric), and all other
|
| 657 |
+
components numeric.
|
| 658 |
+
"""
|
| 659 |
+
epoch, *rest = components
|
| 660 |
+
return f"{epoch}!{'.'.join(rest)}"
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def _is_not_suffix(segment: str) -> bool:
|
| 664 |
+
return not any(
|
| 665 |
+
segment.startswith(prefix) for prefix in ("dev", "a", "b", "rc", "post")
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def _pad_version(left: list[str], right: list[str]) -> tuple[list[str], list[str]]:
|
| 670 |
+
left_split, right_split = [], []
|
| 671 |
+
|
| 672 |
+
# Get the release segment of our versions
|
| 673 |
+
left_split.append(list(itertools.takewhile(lambda x: x.isdigit(), left)))
|
| 674 |
+
right_split.append(list(itertools.takewhile(lambda x: x.isdigit(), right)))
|
| 675 |
+
|
| 676 |
+
# Get the rest of our versions
|
| 677 |
+
left_split.append(left[len(left_split[0]) :])
|
| 678 |
+
right_split.append(right[len(right_split[0]) :])
|
| 679 |
+
|
| 680 |
+
# Insert our padding
|
| 681 |
+
left_split.insert(1, ["0"] * max(0, len(right_split[0]) - len(left_split[0])))
|
| 682 |
+
right_split.insert(1, ["0"] * max(0, len(left_split[0]) - len(right_split[0])))
|
| 683 |
+
|
| 684 |
+
return (
|
| 685 |
+
list(itertools.chain.from_iterable(left_split)),
|
| 686 |
+
list(itertools.chain.from_iterable(right_split)),
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class SpecifierSet(BaseSpecifier):
|
| 691 |
+
"""This class abstracts handling of a set of version specifiers.
|
| 692 |
+
|
| 693 |
+
It can be passed a single specifier (``>=3.0``), a comma-separated list of
|
| 694 |
+
specifiers (``>=3.0,!=3.1``), or no specifier at all.
|
| 695 |
+
"""
|
| 696 |
+
|
| 697 |
+
def __init__(
|
| 698 |
+
self,
|
| 699 |
+
specifiers: str | Iterable[Specifier] = "",
|
| 700 |
+
prereleases: bool | None = None,
|
| 701 |
+
) -> None:
|
| 702 |
+
"""Initialize a SpecifierSet instance.
|
| 703 |
+
|
| 704 |
+
:param specifiers:
|
| 705 |
+
The string representation of a specifier or a comma-separated list of
|
| 706 |
+
specifiers which will be parsed and normalized before use.
|
| 707 |
+
May also be an iterable of ``Specifier`` instances, which will be used
|
| 708 |
+
as is.
|
| 709 |
+
:param prereleases:
|
| 710 |
+
This tells the SpecifierSet if it should accept prerelease versions if
|
| 711 |
+
applicable or not. The default of ``None`` will autodetect it from the
|
| 712 |
+
given specifiers.
|
| 713 |
+
|
| 714 |
+
:raises InvalidSpecifier:
|
| 715 |
+
If the given ``specifiers`` are not parseable than this exception will be
|
| 716 |
+
raised.
|
| 717 |
+
"""
|
| 718 |
+
|
| 719 |
+
if isinstance(specifiers, str):
|
| 720 |
+
# Split on `,` to break each individual specifier into its own item, and
|
| 721 |
+
# strip each item to remove leading/trailing whitespace.
|
| 722 |
+
split_specifiers = [s.strip() for s in specifiers.split(",") if s.strip()]
|
| 723 |
+
|
| 724 |
+
# Make each individual specifier a Specifier and save in a frozen set
|
| 725 |
+
# for later.
|
| 726 |
+
self._specs = frozenset(map(Specifier, split_specifiers))
|
| 727 |
+
else:
|
| 728 |
+
# Save the supplied specifiers in a frozen set.
|
| 729 |
+
self._specs = frozenset(specifiers)
|
| 730 |
+
|
| 731 |
+
# Store our prereleases value so we can use it later to determine if
|
| 732 |
+
# we accept prereleases or not.
|
| 733 |
+
self._prereleases = prereleases
|
| 734 |
+
|
| 735 |
+
@property
|
| 736 |
+
def prereleases(self) -> bool | None:
|
| 737 |
+
# If we have been given an explicit prerelease modifier, then we'll
|
| 738 |
+
# pass that through here.
|
| 739 |
+
if self._prereleases is not None:
|
| 740 |
+
return self._prereleases
|
| 741 |
+
|
| 742 |
+
# If we don't have any specifiers, and we don't have a forced value,
|
| 743 |
+
# then we'll just return None since we don't know if this should have
|
| 744 |
+
# pre-releases or not.
|
| 745 |
+
if not self._specs:
|
| 746 |
+
return None
|
| 747 |
+
|
| 748 |
+
# Otherwise we'll see if any of the given specifiers accept
|
| 749 |
+
# prereleases, if any of them do we'll return True, otherwise False.
|
| 750 |
+
return any(s.prereleases for s in self._specs)
|
| 751 |
+
|
| 752 |
+
@prereleases.setter
|
| 753 |
+
def prereleases(self, value: bool) -> None:
|
| 754 |
+
self._prereleases = value
|
| 755 |
+
|
| 756 |
+
def __repr__(self) -> str:
|
| 757 |
+
"""A representation of the specifier set that shows all internal state.
|
| 758 |
+
|
| 759 |
+
Note that the ordering of the individual specifiers within the set may not
|
| 760 |
+
match the input string.
|
| 761 |
+
|
| 762 |
+
>>> SpecifierSet('>=1.0.0,!=2.0.0')
|
| 763 |
+
<SpecifierSet('!=2.0.0,>=1.0.0')>
|
| 764 |
+
>>> SpecifierSet('>=1.0.0,!=2.0.0', prereleases=False)
|
| 765 |
+
<SpecifierSet('!=2.0.0,>=1.0.0', prereleases=False)>
|
| 766 |
+
>>> SpecifierSet('>=1.0.0,!=2.0.0', prereleases=True)
|
| 767 |
+
<SpecifierSet('!=2.0.0,>=1.0.0', prereleases=True)>
|
| 768 |
+
"""
|
| 769 |
+
pre = (
|
| 770 |
+
f", prereleases={self.prereleases!r}"
|
| 771 |
+
if self._prereleases is not None
|
| 772 |
+
else ""
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
return f"<SpecifierSet({str(self)!r}{pre})>"
|
| 776 |
+
|
| 777 |
+
def __str__(self) -> str:
|
| 778 |
+
"""A string representation of the specifier set that can be round-tripped.
|
| 779 |
+
|
| 780 |
+
Note that the ordering of the individual specifiers within the set may not
|
| 781 |
+
match the input string.
|
| 782 |
+
|
| 783 |
+
>>> str(SpecifierSet(">=1.0.0,!=1.0.1"))
|
| 784 |
+
'!=1.0.1,>=1.0.0'
|
| 785 |
+
>>> str(SpecifierSet(">=1.0.0,!=1.0.1", prereleases=False))
|
| 786 |
+
'!=1.0.1,>=1.0.0'
|
| 787 |
+
"""
|
| 788 |
+
return ",".join(sorted(str(s) for s in self._specs))
|
| 789 |
+
|
| 790 |
+
def __hash__(self) -> int:
|
| 791 |
+
return hash(self._specs)
|
| 792 |
+
|
| 793 |
+
def __and__(self, other: SpecifierSet | str) -> SpecifierSet:
|
| 794 |
+
"""Return a SpecifierSet which is a combination of the two sets.
|
| 795 |
+
|
| 796 |
+
:param other: The other object to combine with.
|
| 797 |
+
|
| 798 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1") & '<=2.0.0,!=2.0.1'
|
| 799 |
+
<SpecifierSet('!=1.0.1,!=2.0.1,<=2.0.0,>=1.0.0')>
|
| 800 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1") & SpecifierSet('<=2.0.0,!=2.0.1')
|
| 801 |
+
<SpecifierSet('!=1.0.1,!=2.0.1,<=2.0.0,>=1.0.0')>
|
| 802 |
+
"""
|
| 803 |
+
if isinstance(other, str):
|
| 804 |
+
other = SpecifierSet(other)
|
| 805 |
+
elif not isinstance(other, SpecifierSet):
|
| 806 |
+
return NotImplemented
|
| 807 |
+
|
| 808 |
+
specifier = SpecifierSet()
|
| 809 |
+
specifier._specs = frozenset(self._specs | other._specs)
|
| 810 |
+
|
| 811 |
+
if self._prereleases is None and other._prereleases is not None:
|
| 812 |
+
specifier._prereleases = other._prereleases
|
| 813 |
+
elif self._prereleases is not None and other._prereleases is None:
|
| 814 |
+
specifier._prereleases = self._prereleases
|
| 815 |
+
elif self._prereleases == other._prereleases:
|
| 816 |
+
specifier._prereleases = self._prereleases
|
| 817 |
+
else:
|
| 818 |
+
raise ValueError(
|
| 819 |
+
"Cannot combine SpecifierSets with True and False prerelease "
|
| 820 |
+
"overrides."
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
return specifier
|
| 824 |
+
|
| 825 |
+
def __eq__(self, other: object) -> bool:
|
| 826 |
+
"""Whether or not the two SpecifierSet-like objects are equal.
|
| 827 |
+
|
| 828 |
+
:param other: The other object to check against.
|
| 829 |
+
|
| 830 |
+
The value of :attr:`prereleases` is ignored.
|
| 831 |
+
|
| 832 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1") == SpecifierSet(">=1.0.0,!=1.0.1")
|
| 833 |
+
True
|
| 834 |
+
>>> (SpecifierSet(">=1.0.0,!=1.0.1", prereleases=False) ==
|
| 835 |
+
... SpecifierSet(">=1.0.0,!=1.0.1", prereleases=True))
|
| 836 |
+
True
|
| 837 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1") == ">=1.0.0,!=1.0.1"
|
| 838 |
+
True
|
| 839 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1") == SpecifierSet(">=1.0.0")
|
| 840 |
+
False
|
| 841 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1") == SpecifierSet(">=1.0.0,!=1.0.2")
|
| 842 |
+
False
|
| 843 |
+
"""
|
| 844 |
+
if isinstance(other, (str, Specifier)):
|
| 845 |
+
other = SpecifierSet(str(other))
|
| 846 |
+
elif not isinstance(other, SpecifierSet):
|
| 847 |
+
return NotImplemented
|
| 848 |
+
|
| 849 |
+
return self._specs == other._specs
|
| 850 |
+
|
| 851 |
+
def __len__(self) -> int:
|
| 852 |
+
"""Returns the number of specifiers in this specifier set."""
|
| 853 |
+
return len(self._specs)
|
| 854 |
+
|
| 855 |
+
def __iter__(self) -> Iterator[Specifier]:
|
| 856 |
+
"""
|
| 857 |
+
Returns an iterator over all the underlying :class:`Specifier` instances
|
| 858 |
+
in this specifier set.
|
| 859 |
+
|
| 860 |
+
>>> sorted(SpecifierSet(">=1.0.0,!=1.0.1"), key=str)
|
| 861 |
+
[<Specifier('!=1.0.1')>, <Specifier('>=1.0.0')>]
|
| 862 |
+
"""
|
| 863 |
+
return iter(self._specs)
|
| 864 |
+
|
| 865 |
+
def __contains__(self, item: UnparsedVersion) -> bool:
|
| 866 |
+
"""Return whether or not the item is contained in this specifier.
|
| 867 |
+
|
| 868 |
+
:param item: The item to check for.
|
| 869 |
+
|
| 870 |
+
This is used for the ``in`` operator and behaves the same as
|
| 871 |
+
:meth:`contains` with no ``prereleases`` argument passed.
|
| 872 |
+
|
| 873 |
+
>>> "1.2.3" in SpecifierSet(">=1.0.0,!=1.0.1")
|
| 874 |
+
True
|
| 875 |
+
>>> Version("1.2.3") in SpecifierSet(">=1.0.0,!=1.0.1")
|
| 876 |
+
True
|
| 877 |
+
>>> "1.0.1" in SpecifierSet(">=1.0.0,!=1.0.1")
|
| 878 |
+
False
|
| 879 |
+
>>> "1.3.0a1" in SpecifierSet(">=1.0.0,!=1.0.1")
|
| 880 |
+
False
|
| 881 |
+
>>> "1.3.0a1" in SpecifierSet(">=1.0.0,!=1.0.1", prereleases=True)
|
| 882 |
+
True
|
| 883 |
+
"""
|
| 884 |
+
return self.contains(item)
|
| 885 |
+
|
| 886 |
+
def contains(
|
| 887 |
+
self,
|
| 888 |
+
item: UnparsedVersion,
|
| 889 |
+
prereleases: bool | None = None,
|
| 890 |
+
installed: bool | None = None,
|
| 891 |
+
) -> bool:
|
| 892 |
+
"""Return whether or not the item is contained in this SpecifierSet.
|
| 893 |
+
|
| 894 |
+
:param item:
|
| 895 |
+
The item to check for, which can be a version string or a
|
| 896 |
+
:class:`Version` instance.
|
| 897 |
+
:param prereleases:
|
| 898 |
+
Whether or not to match prereleases with this SpecifierSet. If set to
|
| 899 |
+
``None`` (the default), it uses :attr:`prereleases` to determine
|
| 900 |
+
whether or not prereleases are allowed.
|
| 901 |
+
|
| 902 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.2.3")
|
| 903 |
+
True
|
| 904 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1").contains(Version("1.2.3"))
|
| 905 |
+
True
|
| 906 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.0.1")
|
| 907 |
+
False
|
| 908 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.3.0a1")
|
| 909 |
+
False
|
| 910 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1", prereleases=True).contains("1.3.0a1")
|
| 911 |
+
True
|
| 912 |
+
>>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.3.0a1", prereleases=True)
|
| 913 |
+
True
|
| 914 |
+
"""
|
| 915 |
+
# Ensure that our item is a Version instance.
|
| 916 |
+
if not isinstance(item, Version):
|
| 917 |
+
item = Version(item)
|
| 918 |
+
|
| 919 |
+
# Determine if we're forcing a prerelease or not, if we're not forcing
|
| 920 |
+
# one for this particular filter call, then we'll use whatever the
|
| 921 |
+
# SpecifierSet thinks for whether or not we should support prereleases.
|
| 922 |
+
if prereleases is None:
|
| 923 |
+
prereleases = self.prereleases
|
| 924 |
+
|
| 925 |
+
# We can determine if we're going to allow pre-releases by looking to
|
| 926 |
+
# see if any of the underlying items supports them. If none of them do
|
| 927 |
+
# and this item is a pre-release then we do not allow it and we can
|
| 928 |
+
# short circuit that here.
|
| 929 |
+
# Note: This means that 1.0.dev1 would not be contained in something
|
| 930 |
+
# like >=1.0.devabc however it would be in >=1.0.debabc,>0.0.dev0
|
| 931 |
+
if not prereleases and item.is_prerelease:
|
| 932 |
+
return False
|
| 933 |
+
|
| 934 |
+
if installed and item.is_prerelease:
|
| 935 |
+
item = Version(item.base_version)
|
| 936 |
+
|
| 937 |
+
# We simply dispatch to the underlying specs here to make sure that the
|
| 938 |
+
# given version is contained within all of them.
|
| 939 |
+
# Note: This use of all() here means that an empty set of specifiers
|
| 940 |
+
# will always return True, this is an explicit design decision.
|
| 941 |
+
return all(s.contains(item, prereleases=prereleases) for s in self._specs)
|
| 942 |
+
|
| 943 |
+
def filter(
|
| 944 |
+
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
|
| 945 |
+
) -> Iterator[UnparsedVersionVar]:
|
| 946 |
+
"""Filter items in the given iterable, that match the specifiers in this set.
|
| 947 |
+
|
| 948 |
+
:param iterable:
|
| 949 |
+
An iterable that can contain version strings and :class:`Version` instances.
|
| 950 |
+
The items in the iterable will be filtered according to the specifier.
|
| 951 |
+
:param prereleases:
|
| 952 |
+
Whether or not to allow prereleases in the returned iterator. If set to
|
| 953 |
+
``None`` (the default), it will be intelligently decide whether to allow
|
| 954 |
+
prereleases or not (based on the :attr:`prereleases` attribute, and
|
| 955 |
+
whether the only versions matching are prereleases).
|
| 956 |
+
|
| 957 |
+
This method is smarter than just ``filter(SpecifierSet(...).contains, [...])``
|
| 958 |
+
because it implements the rule from :pep:`440` that a prerelease item
|
| 959 |
+
SHOULD be accepted if no other versions match the given specifier.
|
| 960 |
+
|
| 961 |
+
>>> list(SpecifierSet(">=1.2.3").filter(["1.2", "1.3", "1.5a1"]))
|
| 962 |
+
['1.3']
|
| 963 |
+
>>> list(SpecifierSet(">=1.2.3").filter(["1.2", "1.3", Version("1.4")]))
|
| 964 |
+
['1.3', <Version('1.4')>]
|
| 965 |
+
>>> list(SpecifierSet(">=1.2.3").filter(["1.2", "1.5a1"]))
|
| 966 |
+
[]
|
| 967 |
+
>>> list(SpecifierSet(">=1.2.3").filter(["1.3", "1.5a1"], prereleases=True))
|
| 968 |
+
['1.3', '1.5a1']
|
| 969 |
+
>>> list(SpecifierSet(">=1.2.3", prereleases=True).filter(["1.3", "1.5a1"]))
|
| 970 |
+
['1.3', '1.5a1']
|
| 971 |
+
|
| 972 |
+
An "empty" SpecifierSet will filter items based on the presence of prerelease
|
| 973 |
+
versions in the set.
|
| 974 |
+
|
| 975 |
+
>>> list(SpecifierSet("").filter(["1.3", "1.5a1"]))
|
| 976 |
+
['1.3']
|
| 977 |
+
>>> list(SpecifierSet("").filter(["1.5a1"]))
|
| 978 |
+
['1.5a1']
|
| 979 |
+
>>> list(SpecifierSet("", prereleases=True).filter(["1.3", "1.5a1"]))
|
| 980 |
+
['1.3', '1.5a1']
|
| 981 |
+
>>> list(SpecifierSet("").filter(["1.3", "1.5a1"], prereleases=True))
|
| 982 |
+
['1.3', '1.5a1']
|
| 983 |
+
"""
|
| 984 |
+
# Determine if we're forcing a prerelease or not, if we're not forcing
|
| 985 |
+
# one for this particular filter call, then we'll use whatever the
|
| 986 |
+
# SpecifierSet thinks for whether or not we should support prereleases.
|
| 987 |
+
if prereleases is None:
|
| 988 |
+
prereleases = self.prereleases
|
| 989 |
+
|
| 990 |
+
# If we have any specifiers, then we want to wrap our iterable in the
|
| 991 |
+
# filter method for each one, this will act as a logical AND amongst
|
| 992 |
+
# each specifier.
|
| 993 |
+
if self._specs:
|
| 994 |
+
for spec in self._specs:
|
| 995 |
+
iterable = spec.filter(iterable, prereleases=bool(prereleases))
|
| 996 |
+
return iter(iterable)
|
| 997 |
+
# If we do not have any specifiers, then we need to have a rough filter
|
| 998 |
+
# which will filter out any pre-releases, unless there are no final
|
| 999 |
+
# releases.
|
| 1000 |
+
else:
|
| 1001 |
+
filtered: list[UnparsedVersionVar] = []
|
| 1002 |
+
found_prereleases: list[UnparsedVersionVar] = []
|
| 1003 |
+
|
| 1004 |
+
for item in iterable:
|
| 1005 |
+
parsed_version = _coerce_version(item)
|
| 1006 |
+
|
| 1007 |
+
# Store any item which is a pre-release for later unless we've
|
| 1008 |
+
# already found a final version or we are accepting prereleases
|
| 1009 |
+
if parsed_version.is_prerelease and not prereleases:
|
| 1010 |
+
if not filtered:
|
| 1011 |
+
found_prereleases.append(item)
|
| 1012 |
+
else:
|
| 1013 |
+
filtered.append(item)
|
| 1014 |
+
|
| 1015 |
+
# If we've found no items except for pre-releases, then we'll go
|
| 1016 |
+
# ahead and use the pre-releases
|
| 1017 |
+
if not filtered and found_prereleases and prereleases is None:
|
| 1018 |
+
return iter(found_prereleases)
|
| 1019 |
+
|
| 1020 |
+
return iter(filtered)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/tags.py
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is dual licensed under the terms of the Apache License, Version
|
| 2 |
+
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
| 3 |
+
# for complete details.
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import platform
|
| 9 |
+
import re
|
| 10 |
+
import struct
|
| 11 |
+
import subprocess
|
| 12 |
+
import sys
|
| 13 |
+
import sysconfig
|
| 14 |
+
from importlib.machinery import EXTENSION_SUFFIXES
|
| 15 |
+
from typing import (
|
| 16 |
+
Iterable,
|
| 17 |
+
Iterator,
|
| 18 |
+
Sequence,
|
| 19 |
+
Tuple,
|
| 20 |
+
cast,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from . import _manylinux, _musllinux
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
PythonVersion = Sequence[int]
|
| 28 |
+
AppleVersion = Tuple[int, int]
|
| 29 |
+
|
| 30 |
+
INTERPRETER_SHORT_NAMES: dict[str, str] = {
|
| 31 |
+
"python": "py", # Generic.
|
| 32 |
+
"cpython": "cp",
|
| 33 |
+
"pypy": "pp",
|
| 34 |
+
"ironpython": "ip",
|
| 35 |
+
"jython": "jy",
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
_32_BIT_INTERPRETER = struct.calcsize("P") == 4
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Tag:
|
| 43 |
+
"""
|
| 44 |
+
A representation of the tag triple for a wheel.
|
| 45 |
+
|
| 46 |
+
Instances are considered immutable and thus are hashable. Equality checking
|
| 47 |
+
is also supported.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
__slots__ = ["_abi", "_hash", "_interpreter", "_platform"]
|
| 51 |
+
|
| 52 |
+
def __init__(self, interpreter: str, abi: str, platform: str) -> None:
|
| 53 |
+
self._interpreter = interpreter.lower()
|
| 54 |
+
self._abi = abi.lower()
|
| 55 |
+
self._platform = platform.lower()
|
| 56 |
+
# The __hash__ of every single element in a Set[Tag] will be evaluated each time
|
| 57 |
+
# that a set calls its `.disjoint()` method, which may be called hundreds of
|
| 58 |
+
# times when scanning a page of links for packages with tags matching that
|
| 59 |
+
# Set[Tag]. Pre-computing the value here produces significant speedups for
|
| 60 |
+
# downstream consumers.
|
| 61 |
+
self._hash = hash((self._interpreter, self._abi, self._platform))
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def interpreter(self) -> str:
|
| 65 |
+
return self._interpreter
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def abi(self) -> str:
|
| 69 |
+
return self._abi
|
| 70 |
+
|
| 71 |
+
@property
|
| 72 |
+
def platform(self) -> str:
|
| 73 |
+
return self._platform
|
| 74 |
+
|
| 75 |
+
def __eq__(self, other: object) -> bool:
|
| 76 |
+
if not isinstance(other, Tag):
|
| 77 |
+
return NotImplemented
|
| 78 |
+
|
| 79 |
+
return (
|
| 80 |
+
(self._hash == other._hash) # Short-circuit ASAP for perf reasons.
|
| 81 |
+
and (self._platform == other._platform)
|
| 82 |
+
and (self._abi == other._abi)
|
| 83 |
+
and (self._interpreter == other._interpreter)
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def __hash__(self) -> int:
|
| 87 |
+
return self._hash
|
| 88 |
+
|
| 89 |
+
def __str__(self) -> str:
|
| 90 |
+
return f"{self._interpreter}-{self._abi}-{self._platform}"
|
| 91 |
+
|
| 92 |
+
def __repr__(self) -> str:
|
| 93 |
+
return f"<{self} @ {id(self)}>"
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def parse_tag(tag: str) -> frozenset[Tag]:
|
| 97 |
+
"""
|
| 98 |
+
Parses the provided tag (e.g. `py3-none-any`) into a frozenset of Tag instances.
|
| 99 |
+
|
| 100 |
+
Returning a set is required due to the possibility that the tag is a
|
| 101 |
+
compressed tag set.
|
| 102 |
+
"""
|
| 103 |
+
tags = set()
|
| 104 |
+
interpreters, abis, platforms = tag.split("-")
|
| 105 |
+
for interpreter in interpreters.split("."):
|
| 106 |
+
for abi in abis.split("."):
|
| 107 |
+
for platform_ in platforms.split("."):
|
| 108 |
+
tags.add(Tag(interpreter, abi, platform_))
|
| 109 |
+
return frozenset(tags)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _get_config_var(name: str, warn: bool = False) -> int | str | None:
|
| 113 |
+
value: int | str | None = sysconfig.get_config_var(name)
|
| 114 |
+
if value is None and warn:
|
| 115 |
+
logger.debug(
|
| 116 |
+
"Config variable '%s' is unset, Python ABI tag may be incorrect", name
|
| 117 |
+
)
|
| 118 |
+
return value
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _normalize_string(string: str) -> str:
|
| 122 |
+
return string.replace(".", "_").replace("-", "_").replace(" ", "_")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _is_threaded_cpython(abis: list[str]) -> bool:
|
| 126 |
+
"""
|
| 127 |
+
Determine if the ABI corresponds to a threaded (`--disable-gil`) build.
|
| 128 |
+
|
| 129 |
+
The threaded builds are indicated by a "t" in the abiflags.
|
| 130 |
+
"""
|
| 131 |
+
if len(abis) == 0:
|
| 132 |
+
return False
|
| 133 |
+
# expect e.g., cp313
|
| 134 |
+
m = re.match(r"cp\d+(.*)", abis[0])
|
| 135 |
+
if not m:
|
| 136 |
+
return False
|
| 137 |
+
abiflags = m.group(1)
|
| 138 |
+
return "t" in abiflags
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _abi3_applies(python_version: PythonVersion, threading: bool) -> bool:
|
| 142 |
+
"""
|
| 143 |
+
Determine if the Python version supports abi3.
|
| 144 |
+
|
| 145 |
+
PEP 384 was first implemented in Python 3.2. The threaded (`--disable-gil`)
|
| 146 |
+
builds do not support abi3.
|
| 147 |
+
"""
|
| 148 |
+
return len(python_version) > 1 and tuple(python_version) >= (3, 2) and not threading
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _cpython_abis(py_version: PythonVersion, warn: bool = False) -> list[str]:
|
| 152 |
+
py_version = tuple(py_version) # To allow for version comparison.
|
| 153 |
+
abis = []
|
| 154 |
+
version = _version_nodot(py_version[:2])
|
| 155 |
+
threading = debug = pymalloc = ucs4 = ""
|
| 156 |
+
with_debug = _get_config_var("Py_DEBUG", warn)
|
| 157 |
+
has_refcount = hasattr(sys, "gettotalrefcount")
|
| 158 |
+
# Windows doesn't set Py_DEBUG, so checking for support of debug-compiled
|
| 159 |
+
# extension modules is the best option.
|
| 160 |
+
# https://github.com/pypa/pip/issues/3383#issuecomment-173267692
|
| 161 |
+
has_ext = "_d.pyd" in EXTENSION_SUFFIXES
|
| 162 |
+
if with_debug or (with_debug is None and (has_refcount or has_ext)):
|
| 163 |
+
debug = "d"
|
| 164 |
+
if py_version >= (3, 13) and _get_config_var("Py_GIL_DISABLED", warn):
|
| 165 |
+
threading = "t"
|
| 166 |
+
if py_version < (3, 8):
|
| 167 |
+
with_pymalloc = _get_config_var("WITH_PYMALLOC", warn)
|
| 168 |
+
if with_pymalloc or with_pymalloc is None:
|
| 169 |
+
pymalloc = "m"
|
| 170 |
+
if py_version < (3, 3):
|
| 171 |
+
unicode_size = _get_config_var("Py_UNICODE_SIZE", warn)
|
| 172 |
+
if unicode_size == 4 or (
|
| 173 |
+
unicode_size is None and sys.maxunicode == 0x10FFFF
|
| 174 |
+
):
|
| 175 |
+
ucs4 = "u"
|
| 176 |
+
elif debug:
|
| 177 |
+
# Debug builds can also load "normal" extension modules.
|
| 178 |
+
# We can also assume no UCS-4 or pymalloc requirement.
|
| 179 |
+
abis.append(f"cp{version}{threading}")
|
| 180 |
+
abis.insert(0, f"cp{version}{threading}{debug}{pymalloc}{ucs4}")
|
| 181 |
+
return abis
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def cpython_tags(
|
| 185 |
+
python_version: PythonVersion | None = None,
|
| 186 |
+
abis: Iterable[str] | None = None,
|
| 187 |
+
platforms: Iterable[str] | None = None,
|
| 188 |
+
*,
|
| 189 |
+
warn: bool = False,
|
| 190 |
+
) -> Iterator[Tag]:
|
| 191 |
+
"""
|
| 192 |
+
Yields the tags for a CPython interpreter.
|
| 193 |
+
|
| 194 |
+
The tags consist of:
|
| 195 |
+
- cp<python_version>-<abi>-<platform>
|
| 196 |
+
- cp<python_version>-abi3-<platform>
|
| 197 |
+
- cp<python_version>-none-<platform>
|
| 198 |
+
- cp<less than python_version>-abi3-<platform> # Older Python versions down to 3.2.
|
| 199 |
+
|
| 200 |
+
If python_version only specifies a major version then user-provided ABIs and
|
| 201 |
+
the 'none' ABItag will be used.
|
| 202 |
+
|
| 203 |
+
If 'abi3' or 'none' are specified in 'abis' then they will be yielded at
|
| 204 |
+
their normal position and not at the beginning.
|
| 205 |
+
"""
|
| 206 |
+
if not python_version:
|
| 207 |
+
python_version = sys.version_info[:2]
|
| 208 |
+
|
| 209 |
+
interpreter = f"cp{_version_nodot(python_version[:2])}"
|
| 210 |
+
|
| 211 |
+
if abis is None:
|
| 212 |
+
if len(python_version) > 1:
|
| 213 |
+
abis = _cpython_abis(python_version, warn)
|
| 214 |
+
else:
|
| 215 |
+
abis = []
|
| 216 |
+
abis = list(abis)
|
| 217 |
+
# 'abi3' and 'none' are explicitly handled later.
|
| 218 |
+
for explicit_abi in ("abi3", "none"):
|
| 219 |
+
try:
|
| 220 |
+
abis.remove(explicit_abi)
|
| 221 |
+
except ValueError:
|
| 222 |
+
pass
|
| 223 |
+
|
| 224 |
+
platforms = list(platforms or platform_tags())
|
| 225 |
+
for abi in abis:
|
| 226 |
+
for platform_ in platforms:
|
| 227 |
+
yield Tag(interpreter, abi, platform_)
|
| 228 |
+
|
| 229 |
+
threading = _is_threaded_cpython(abis)
|
| 230 |
+
use_abi3 = _abi3_applies(python_version, threading)
|
| 231 |
+
if use_abi3:
|
| 232 |
+
yield from (Tag(interpreter, "abi3", platform_) for platform_ in platforms)
|
| 233 |
+
yield from (Tag(interpreter, "none", platform_) for platform_ in platforms)
|
| 234 |
+
|
| 235 |
+
if use_abi3:
|
| 236 |
+
for minor_version in range(python_version[1] - 1, 1, -1):
|
| 237 |
+
for platform_ in platforms:
|
| 238 |
+
version = _version_nodot((python_version[0], minor_version))
|
| 239 |
+
interpreter = f"cp{version}"
|
| 240 |
+
yield Tag(interpreter, "abi3", platform_)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _generic_abi() -> list[str]:
|
| 244 |
+
"""
|
| 245 |
+
Return the ABI tag based on EXT_SUFFIX.
|
| 246 |
+
"""
|
| 247 |
+
# The following are examples of `EXT_SUFFIX`.
|
| 248 |
+
# We want to keep the parts which are related to the ABI and remove the
|
| 249 |
+
# parts which are related to the platform:
|
| 250 |
+
# - linux: '.cpython-310-x86_64-linux-gnu.so' => cp310
|
| 251 |
+
# - mac: '.cpython-310-darwin.so' => cp310
|
| 252 |
+
# - win: '.cp310-win_amd64.pyd' => cp310
|
| 253 |
+
# - win: '.pyd' => cp37 (uses _cpython_abis())
|
| 254 |
+
# - pypy: '.pypy38-pp73-x86_64-linux-gnu.so' => pypy38_pp73
|
| 255 |
+
# - graalpy: '.graalpy-38-native-x86_64-darwin.dylib'
|
| 256 |
+
# => graalpy_38_native
|
| 257 |
+
|
| 258 |
+
ext_suffix = _get_config_var("EXT_SUFFIX", warn=True)
|
| 259 |
+
if not isinstance(ext_suffix, str) or ext_suffix[0] != ".":
|
| 260 |
+
raise SystemError("invalid sysconfig.get_config_var('EXT_SUFFIX')")
|
| 261 |
+
parts = ext_suffix.split(".")
|
| 262 |
+
if len(parts) < 3:
|
| 263 |
+
# CPython3.7 and earlier uses ".pyd" on Windows.
|
| 264 |
+
return _cpython_abis(sys.version_info[:2])
|
| 265 |
+
soabi = parts[1]
|
| 266 |
+
if soabi.startswith("cpython"):
|
| 267 |
+
# non-windows
|
| 268 |
+
abi = "cp" + soabi.split("-")[1]
|
| 269 |
+
elif soabi.startswith("cp"):
|
| 270 |
+
# windows
|
| 271 |
+
abi = soabi.split("-")[0]
|
| 272 |
+
elif soabi.startswith("pypy"):
|
| 273 |
+
abi = "-".join(soabi.split("-")[:2])
|
| 274 |
+
elif soabi.startswith("graalpy"):
|
| 275 |
+
abi = "-".join(soabi.split("-")[:3])
|
| 276 |
+
elif soabi:
|
| 277 |
+
# pyston, ironpython, others?
|
| 278 |
+
abi = soabi
|
| 279 |
+
else:
|
| 280 |
+
return []
|
| 281 |
+
return [_normalize_string(abi)]
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def generic_tags(
|
| 285 |
+
interpreter: str | None = None,
|
| 286 |
+
abis: Iterable[str] | None = None,
|
| 287 |
+
platforms: Iterable[str] | None = None,
|
| 288 |
+
*,
|
| 289 |
+
warn: bool = False,
|
| 290 |
+
) -> Iterator[Tag]:
|
| 291 |
+
"""
|
| 292 |
+
Yields the tags for a generic interpreter.
|
| 293 |
+
|
| 294 |
+
The tags consist of:
|
| 295 |
+
- <interpreter>-<abi>-<platform>
|
| 296 |
+
|
| 297 |
+
The "none" ABI will be added if it was not explicitly provided.
|
| 298 |
+
"""
|
| 299 |
+
if not interpreter:
|
| 300 |
+
interp_name = interpreter_name()
|
| 301 |
+
interp_version = interpreter_version(warn=warn)
|
| 302 |
+
interpreter = "".join([interp_name, interp_version])
|
| 303 |
+
if abis is None:
|
| 304 |
+
abis = _generic_abi()
|
| 305 |
+
else:
|
| 306 |
+
abis = list(abis)
|
| 307 |
+
platforms = list(platforms or platform_tags())
|
| 308 |
+
if "none" not in abis:
|
| 309 |
+
abis.append("none")
|
| 310 |
+
for abi in abis:
|
| 311 |
+
for platform_ in platforms:
|
| 312 |
+
yield Tag(interpreter, abi, platform_)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def _py_interpreter_range(py_version: PythonVersion) -> Iterator[str]:
|
| 316 |
+
"""
|
| 317 |
+
Yields Python versions in descending order.
|
| 318 |
+
|
| 319 |
+
After the latest version, the major-only version will be yielded, and then
|
| 320 |
+
all previous versions of that major version.
|
| 321 |
+
"""
|
| 322 |
+
if len(py_version) > 1:
|
| 323 |
+
yield f"py{_version_nodot(py_version[:2])}"
|
| 324 |
+
yield f"py{py_version[0]}"
|
| 325 |
+
if len(py_version) > 1:
|
| 326 |
+
for minor in range(py_version[1] - 1, -1, -1):
|
| 327 |
+
yield f"py{_version_nodot((py_version[0], minor))}"
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def compatible_tags(
|
| 331 |
+
python_version: PythonVersion | None = None,
|
| 332 |
+
interpreter: str | None = None,
|
| 333 |
+
platforms: Iterable[str] | None = None,
|
| 334 |
+
) -> Iterator[Tag]:
|
| 335 |
+
"""
|
| 336 |
+
Yields the sequence of tags that are compatible with a specific version of Python.
|
| 337 |
+
|
| 338 |
+
The tags consist of:
|
| 339 |
+
- py*-none-<platform>
|
| 340 |
+
- <interpreter>-none-any # ... if `interpreter` is provided.
|
| 341 |
+
- py*-none-any
|
| 342 |
+
"""
|
| 343 |
+
if not python_version:
|
| 344 |
+
python_version = sys.version_info[:2]
|
| 345 |
+
platforms = list(platforms or platform_tags())
|
| 346 |
+
for version in _py_interpreter_range(python_version):
|
| 347 |
+
for platform_ in platforms:
|
| 348 |
+
yield Tag(version, "none", platform_)
|
| 349 |
+
if interpreter:
|
| 350 |
+
yield Tag(interpreter, "none", "any")
|
| 351 |
+
for version in _py_interpreter_range(python_version):
|
| 352 |
+
yield Tag(version, "none", "any")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _mac_arch(arch: str, is_32bit: bool = _32_BIT_INTERPRETER) -> str:
|
| 356 |
+
if not is_32bit:
|
| 357 |
+
return arch
|
| 358 |
+
|
| 359 |
+
if arch.startswith("ppc"):
|
| 360 |
+
return "ppc"
|
| 361 |
+
|
| 362 |
+
return "i386"
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def _mac_binary_formats(version: AppleVersion, cpu_arch: str) -> list[str]:
|
| 366 |
+
formats = [cpu_arch]
|
| 367 |
+
if cpu_arch == "x86_64":
|
| 368 |
+
if version < (10, 4):
|
| 369 |
+
return []
|
| 370 |
+
formats.extend(["intel", "fat64", "fat32"])
|
| 371 |
+
|
| 372 |
+
elif cpu_arch == "i386":
|
| 373 |
+
if version < (10, 4):
|
| 374 |
+
return []
|
| 375 |
+
formats.extend(["intel", "fat32", "fat"])
|
| 376 |
+
|
| 377 |
+
elif cpu_arch == "ppc64":
|
| 378 |
+
# TODO: Need to care about 32-bit PPC for ppc64 through 10.2?
|
| 379 |
+
if version > (10, 5) or version < (10, 4):
|
| 380 |
+
return []
|
| 381 |
+
formats.append("fat64")
|
| 382 |
+
|
| 383 |
+
elif cpu_arch == "ppc":
|
| 384 |
+
if version > (10, 6):
|
| 385 |
+
return []
|
| 386 |
+
formats.extend(["fat32", "fat"])
|
| 387 |
+
|
| 388 |
+
if cpu_arch in {"arm64", "x86_64"}:
|
| 389 |
+
formats.append("universal2")
|
| 390 |
+
|
| 391 |
+
if cpu_arch in {"x86_64", "i386", "ppc64", "ppc", "intel"}:
|
| 392 |
+
formats.append("universal")
|
| 393 |
+
|
| 394 |
+
return formats
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def mac_platforms(
|
| 398 |
+
version: AppleVersion | None = None, arch: str | None = None
|
| 399 |
+
) -> Iterator[str]:
|
| 400 |
+
"""
|
| 401 |
+
Yields the platform tags for a macOS system.
|
| 402 |
+
|
| 403 |
+
The `version` parameter is a two-item tuple specifying the macOS version to
|
| 404 |
+
generate platform tags for. The `arch` parameter is the CPU architecture to
|
| 405 |
+
generate platform tags for. Both parameters default to the appropriate value
|
| 406 |
+
for the current system.
|
| 407 |
+
"""
|
| 408 |
+
version_str, _, cpu_arch = platform.mac_ver()
|
| 409 |
+
if version is None:
|
| 410 |
+
version = cast("AppleVersion", tuple(map(int, version_str.split(".")[:2])))
|
| 411 |
+
if version == (10, 16):
|
| 412 |
+
# When built against an older macOS SDK, Python will report macOS 10.16
|
| 413 |
+
# instead of the real version.
|
| 414 |
+
version_str = subprocess.run(
|
| 415 |
+
[
|
| 416 |
+
sys.executable,
|
| 417 |
+
"-sS",
|
| 418 |
+
"-c",
|
| 419 |
+
"import platform; print(platform.mac_ver()[0])",
|
| 420 |
+
],
|
| 421 |
+
check=True,
|
| 422 |
+
env={"SYSTEM_VERSION_COMPAT": "0"},
|
| 423 |
+
stdout=subprocess.PIPE,
|
| 424 |
+
text=True,
|
| 425 |
+
).stdout
|
| 426 |
+
version = cast("AppleVersion", tuple(map(int, version_str.split(".")[:2])))
|
| 427 |
+
else:
|
| 428 |
+
version = version
|
| 429 |
+
if arch is None:
|
| 430 |
+
arch = _mac_arch(cpu_arch)
|
| 431 |
+
else:
|
| 432 |
+
arch = arch
|
| 433 |
+
|
| 434 |
+
if (10, 0) <= version and version < (11, 0):
|
| 435 |
+
# Prior to Mac OS 11, each yearly release of Mac OS bumped the
|
| 436 |
+
# "minor" version number. The major version was always 10.
|
| 437 |
+
major_version = 10
|
| 438 |
+
for minor_version in range(version[1], -1, -1):
|
| 439 |
+
compat_version = major_version, minor_version
|
| 440 |
+
binary_formats = _mac_binary_formats(compat_version, arch)
|
| 441 |
+
for binary_format in binary_formats:
|
| 442 |
+
yield f"macosx_{major_version}_{minor_version}_{binary_format}"
|
| 443 |
+
|
| 444 |
+
if version >= (11, 0):
|
| 445 |
+
# Starting with Mac OS 11, each yearly release bumps the major version
|
| 446 |
+
# number. The minor versions are now the midyear updates.
|
| 447 |
+
minor_version = 0
|
| 448 |
+
for major_version in range(version[0], 10, -1):
|
| 449 |
+
compat_version = major_version, minor_version
|
| 450 |
+
binary_formats = _mac_binary_formats(compat_version, arch)
|
| 451 |
+
for binary_format in binary_formats:
|
| 452 |
+
yield f"macosx_{major_version}_{minor_version}_{binary_format}"
|
| 453 |
+
|
| 454 |
+
if version >= (11, 0):
|
| 455 |
+
# Mac OS 11 on x86_64 is compatible with binaries from previous releases.
|
| 456 |
+
# Arm64 support was introduced in 11.0, so no Arm binaries from previous
|
| 457 |
+
# releases exist.
|
| 458 |
+
#
|
| 459 |
+
# However, the "universal2" binary format can have a
|
| 460 |
+
# macOS version earlier than 11.0 when the x86_64 part of the binary supports
|
| 461 |
+
# that version of macOS.
|
| 462 |
+
major_version = 10
|
| 463 |
+
if arch == "x86_64":
|
| 464 |
+
for minor_version in range(16, 3, -1):
|
| 465 |
+
compat_version = major_version, minor_version
|
| 466 |
+
binary_formats = _mac_binary_formats(compat_version, arch)
|
| 467 |
+
for binary_format in binary_formats:
|
| 468 |
+
yield f"macosx_{major_version}_{minor_version}_{binary_format}"
|
| 469 |
+
else:
|
| 470 |
+
for minor_version in range(16, 3, -1):
|
| 471 |
+
compat_version = major_version, minor_version
|
| 472 |
+
binary_format = "universal2"
|
| 473 |
+
yield f"macosx_{major_version}_{minor_version}_{binary_format}"
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def ios_platforms(
|
| 477 |
+
version: AppleVersion | None = None, multiarch: str | None = None
|
| 478 |
+
) -> Iterator[str]:
|
| 479 |
+
"""
|
| 480 |
+
Yields the platform tags for an iOS system.
|
| 481 |
+
|
| 482 |
+
:param version: A two-item tuple specifying the iOS version to generate
|
| 483 |
+
platform tags for. Defaults to the current iOS version.
|
| 484 |
+
:param multiarch: The CPU architecture+ABI to generate platform tags for -
|
| 485 |
+
(the value used by `sys.implementation._multiarch` e.g.,
|
| 486 |
+
`arm64_iphoneos` or `x84_64_iphonesimulator`). Defaults to the current
|
| 487 |
+
multiarch value.
|
| 488 |
+
"""
|
| 489 |
+
if version is None:
|
| 490 |
+
# if iOS is the current platform, ios_ver *must* be defined. However,
|
| 491 |
+
# it won't exist for CPython versions before 3.13, which causes a mypy
|
| 492 |
+
# error.
|
| 493 |
+
_, release, _, _ = platform.ios_ver() # type: ignore[attr-defined, unused-ignore]
|
| 494 |
+
version = cast("AppleVersion", tuple(map(int, release.split(".")[:2])))
|
| 495 |
+
|
| 496 |
+
if multiarch is None:
|
| 497 |
+
multiarch = sys.implementation._multiarch
|
| 498 |
+
multiarch = multiarch.replace("-", "_")
|
| 499 |
+
|
| 500 |
+
ios_platform_template = "ios_{major}_{minor}_{multiarch}"
|
| 501 |
+
|
| 502 |
+
# Consider any iOS major.minor version from the version requested, down to
|
| 503 |
+
# 12.0. 12.0 is the first iOS version that is known to have enough features
|
| 504 |
+
# to support CPython. Consider every possible minor release up to X.9. There
|
| 505 |
+
# highest the minor has ever gone is 8 (14.8 and 15.8) but having some extra
|
| 506 |
+
# candidates that won't ever match doesn't really hurt, and it saves us from
|
| 507 |
+
# having to keep an explicit list of known iOS versions in the code. Return
|
| 508 |
+
# the results descending order of version number.
|
| 509 |
+
|
| 510 |
+
# If the requested major version is less than 12, there won't be any matches.
|
| 511 |
+
if version[0] < 12:
|
| 512 |
+
return
|
| 513 |
+
|
| 514 |
+
# Consider the actual X.Y version that was requested.
|
| 515 |
+
yield ios_platform_template.format(
|
| 516 |
+
major=version[0], minor=version[1], multiarch=multiarch
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
# Consider every minor version from X.0 to the minor version prior to the
|
| 520 |
+
# version requested by the platform.
|
| 521 |
+
for minor in range(version[1] - 1, -1, -1):
|
| 522 |
+
yield ios_platform_template.format(
|
| 523 |
+
major=version[0], minor=minor, multiarch=multiarch
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
for major in range(version[0] - 1, 11, -1):
|
| 527 |
+
for minor in range(9, -1, -1):
|
| 528 |
+
yield ios_platform_template.format(
|
| 529 |
+
major=major, minor=minor, multiarch=multiarch
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def _linux_platforms(is_32bit: bool = _32_BIT_INTERPRETER) -> Iterator[str]:
|
| 534 |
+
linux = _normalize_string(sysconfig.get_platform())
|
| 535 |
+
if not linux.startswith("linux_"):
|
| 536 |
+
# we should never be here, just yield the sysconfig one and return
|
| 537 |
+
yield linux
|
| 538 |
+
return
|
| 539 |
+
if is_32bit:
|
| 540 |
+
if linux == "linux_x86_64":
|
| 541 |
+
linux = "linux_i686"
|
| 542 |
+
elif linux == "linux_aarch64":
|
| 543 |
+
linux = "linux_armv8l"
|
| 544 |
+
_, arch = linux.split("_", 1)
|
| 545 |
+
archs = {"armv8l": ["armv8l", "armv7l"]}.get(arch, [arch])
|
| 546 |
+
yield from _manylinux.platform_tags(archs)
|
| 547 |
+
yield from _musllinux.platform_tags(archs)
|
| 548 |
+
for arch in archs:
|
| 549 |
+
yield f"linux_{arch}"
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def _generic_platforms() -> Iterator[str]:
|
| 553 |
+
yield _normalize_string(sysconfig.get_platform())
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def platform_tags() -> Iterator[str]:
|
| 557 |
+
"""
|
| 558 |
+
Provides the platform tags for this installation.
|
| 559 |
+
"""
|
| 560 |
+
if platform.system() == "Darwin":
|
| 561 |
+
return mac_platforms()
|
| 562 |
+
elif platform.system() == "iOS":
|
| 563 |
+
return ios_platforms()
|
| 564 |
+
elif platform.system() == "Linux":
|
| 565 |
+
return _linux_platforms()
|
| 566 |
+
else:
|
| 567 |
+
return _generic_platforms()
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def interpreter_name() -> str:
|
| 571 |
+
"""
|
| 572 |
+
Returns the name of the running interpreter.
|
| 573 |
+
|
| 574 |
+
Some implementations have a reserved, two-letter abbreviation which will
|
| 575 |
+
be returned when appropriate.
|
| 576 |
+
"""
|
| 577 |
+
name = sys.implementation.name
|
| 578 |
+
return INTERPRETER_SHORT_NAMES.get(name) or name
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def interpreter_version(*, warn: bool = False) -> str:
|
| 582 |
+
"""
|
| 583 |
+
Returns the version of the running interpreter.
|
| 584 |
+
"""
|
| 585 |
+
version = _get_config_var("py_version_nodot", warn=warn)
|
| 586 |
+
if version:
|
| 587 |
+
version = str(version)
|
| 588 |
+
else:
|
| 589 |
+
version = _version_nodot(sys.version_info[:2])
|
| 590 |
+
return version
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def _version_nodot(version: PythonVersion) -> str:
|
| 594 |
+
return "".join(map(str, version))
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def sys_tags(*, warn: bool = False) -> Iterator[Tag]:
|
| 598 |
+
"""
|
| 599 |
+
Returns the sequence of tag triples for the running interpreter.
|
| 600 |
+
|
| 601 |
+
The order of the sequence corresponds to priority order for the
|
| 602 |
+
interpreter, from most to least important.
|
| 603 |
+
"""
|
| 604 |
+
|
| 605 |
+
interp_name = interpreter_name()
|
| 606 |
+
if interp_name == "cp":
|
| 607 |
+
yield from cpython_tags(warn=warn)
|
| 608 |
+
else:
|
| 609 |
+
yield from generic_tags()
|
| 610 |
+
|
| 611 |
+
if interp_name == "pp":
|
| 612 |
+
interp = "pp3"
|
| 613 |
+
elif interp_name == "cp":
|
| 614 |
+
interp = "cp" + interpreter_version(warn=warn)
|
| 615 |
+
else:
|
| 616 |
+
interp = None
|
| 617 |
+
yield from compatible_tags(interpreter=interp)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/packaging/version.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is dual licensed under the terms of the Apache License, Version
|
| 2 |
+
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
| 3 |
+
# for complete details.
|
| 4 |
+
"""
|
| 5 |
+
.. testsetup::
|
| 6 |
+
|
| 7 |
+
from packaging.version import parse, Version
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import itertools
|
| 13 |
+
import re
|
| 14 |
+
from typing import Any, Callable, NamedTuple, SupportsInt, Tuple, Union
|
| 15 |
+
|
| 16 |
+
from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType
|
| 17 |
+
|
| 18 |
+
__all__ = ["VERSION_PATTERN", "InvalidVersion", "Version", "parse"]
|
| 19 |
+
|
| 20 |
+
LocalType = Tuple[Union[int, str], ...]
|
| 21 |
+
|
| 22 |
+
CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]]
|
| 23 |
+
CmpLocalType = Union[
|
| 24 |
+
NegativeInfinityType,
|
| 25 |
+
Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...],
|
| 26 |
+
]
|
| 27 |
+
CmpKey = Tuple[
|
| 28 |
+
int,
|
| 29 |
+
Tuple[int, ...],
|
| 30 |
+
CmpPrePostDevType,
|
| 31 |
+
CmpPrePostDevType,
|
| 32 |
+
CmpPrePostDevType,
|
| 33 |
+
CmpLocalType,
|
| 34 |
+
]
|
| 35 |
+
VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class _Version(NamedTuple):
|
| 39 |
+
epoch: int
|
| 40 |
+
release: tuple[int, ...]
|
| 41 |
+
dev: tuple[str, int] | None
|
| 42 |
+
pre: tuple[str, int] | None
|
| 43 |
+
post: tuple[str, int] | None
|
| 44 |
+
local: LocalType | None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def parse(version: str) -> Version:
|
| 48 |
+
"""Parse the given version string.
|
| 49 |
+
|
| 50 |
+
>>> parse('1.0.dev1')
|
| 51 |
+
<Version('1.0.dev1')>
|
| 52 |
+
|
| 53 |
+
:param version: The version string to parse.
|
| 54 |
+
:raises InvalidVersion: When the version string is not a valid version.
|
| 55 |
+
"""
|
| 56 |
+
return Version(version)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class InvalidVersion(ValueError):
|
| 60 |
+
"""Raised when a version string is not a valid version.
|
| 61 |
+
|
| 62 |
+
>>> Version("invalid")
|
| 63 |
+
Traceback (most recent call last):
|
| 64 |
+
...
|
| 65 |
+
packaging.version.InvalidVersion: Invalid version: 'invalid'
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class _BaseVersion:
|
| 70 |
+
_key: tuple[Any, ...]
|
| 71 |
+
|
| 72 |
+
def __hash__(self) -> int:
|
| 73 |
+
return hash(self._key)
|
| 74 |
+
|
| 75 |
+
# Please keep the duplicated `isinstance` check
|
| 76 |
+
# in the six comparisons hereunder
|
| 77 |
+
# unless you find a way to avoid adding overhead function calls.
|
| 78 |
+
def __lt__(self, other: _BaseVersion) -> bool:
|
| 79 |
+
if not isinstance(other, _BaseVersion):
|
| 80 |
+
return NotImplemented
|
| 81 |
+
|
| 82 |
+
return self._key < other._key
|
| 83 |
+
|
| 84 |
+
def __le__(self, other: _BaseVersion) -> bool:
|
| 85 |
+
if not isinstance(other, _BaseVersion):
|
| 86 |
+
return NotImplemented
|
| 87 |
+
|
| 88 |
+
return self._key <= other._key
|
| 89 |
+
|
| 90 |
+
def __eq__(self, other: object) -> bool:
|
| 91 |
+
if not isinstance(other, _BaseVersion):
|
| 92 |
+
return NotImplemented
|
| 93 |
+
|
| 94 |
+
return self._key == other._key
|
| 95 |
+
|
| 96 |
+
def __ge__(self, other: _BaseVersion) -> bool:
|
| 97 |
+
if not isinstance(other, _BaseVersion):
|
| 98 |
+
return NotImplemented
|
| 99 |
+
|
| 100 |
+
return self._key >= other._key
|
| 101 |
+
|
| 102 |
+
def __gt__(self, other: _BaseVersion) -> bool:
|
| 103 |
+
if not isinstance(other, _BaseVersion):
|
| 104 |
+
return NotImplemented
|
| 105 |
+
|
| 106 |
+
return self._key > other._key
|
| 107 |
+
|
| 108 |
+
def __ne__(self, other: object) -> bool:
|
| 109 |
+
if not isinstance(other, _BaseVersion):
|
| 110 |
+
return NotImplemented
|
| 111 |
+
|
| 112 |
+
return self._key != other._key
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Deliberately not anchored to the start and end of the string, to make it
|
| 116 |
+
# easier for 3rd party code to reuse
|
| 117 |
+
_VERSION_PATTERN = r"""
|
| 118 |
+
v?
|
| 119 |
+
(?:
|
| 120 |
+
(?:(?P<epoch>[0-9]+)!)? # epoch
|
| 121 |
+
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
|
| 122 |
+
(?P<pre> # pre-release
|
| 123 |
+
[-_\.]?
|
| 124 |
+
(?P<pre_l>alpha|a|beta|b|preview|pre|c|rc)
|
| 125 |
+
[-_\.]?
|
| 126 |
+
(?P<pre_n>[0-9]+)?
|
| 127 |
+
)?
|
| 128 |
+
(?P<post> # post release
|
| 129 |
+
(?:-(?P<post_n1>[0-9]+))
|
| 130 |
+
|
|
| 131 |
+
(?:
|
| 132 |
+
[-_\.]?
|
| 133 |
+
(?P<post_l>post|rev|r)
|
| 134 |
+
[-_\.]?
|
| 135 |
+
(?P<post_n2>[0-9]+)?
|
| 136 |
+
)
|
| 137 |
+
)?
|
| 138 |
+
(?P<dev> # dev release
|
| 139 |
+
[-_\.]?
|
| 140 |
+
(?P<dev_l>dev)
|
| 141 |
+
[-_\.]?
|
| 142 |
+
(?P<dev_n>[0-9]+)?
|
| 143 |
+
)?
|
| 144 |
+
)
|
| 145 |
+
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
VERSION_PATTERN = _VERSION_PATTERN
|
| 149 |
+
"""
|
| 150 |
+
A string containing the regular expression used to match a valid version.
|
| 151 |
+
|
| 152 |
+
The pattern is not anchored at either end, and is intended for embedding in larger
|
| 153 |
+
expressions (for example, matching a version number as part of a file name). The
|
| 154 |
+
regular expression should be compiled with the ``re.VERBOSE`` and ``re.IGNORECASE``
|
| 155 |
+
flags set.
|
| 156 |
+
|
| 157 |
+
:meta hide-value:
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class Version(_BaseVersion):
|
| 162 |
+
"""This class abstracts handling of a project's versions.
|
| 163 |
+
|
| 164 |
+
A :class:`Version` instance is comparison aware and can be compared and
|
| 165 |
+
sorted using the standard Python interfaces.
|
| 166 |
+
|
| 167 |
+
>>> v1 = Version("1.0a5")
|
| 168 |
+
>>> v2 = Version("1.0")
|
| 169 |
+
>>> v1
|
| 170 |
+
<Version('1.0a5')>
|
| 171 |
+
>>> v2
|
| 172 |
+
<Version('1.0')>
|
| 173 |
+
>>> v1 < v2
|
| 174 |
+
True
|
| 175 |
+
>>> v1 == v2
|
| 176 |
+
False
|
| 177 |
+
>>> v1 > v2
|
| 178 |
+
False
|
| 179 |
+
>>> v1 >= v2
|
| 180 |
+
False
|
| 181 |
+
>>> v1 <= v2
|
| 182 |
+
True
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
_regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
|
| 186 |
+
_key: CmpKey
|
| 187 |
+
|
| 188 |
+
def __init__(self, version: str) -> None:
|
| 189 |
+
"""Initialize a Version object.
|
| 190 |
+
|
| 191 |
+
:param version:
|
| 192 |
+
The string representation of a version which will be parsed and normalized
|
| 193 |
+
before use.
|
| 194 |
+
:raises InvalidVersion:
|
| 195 |
+
If the ``version`` does not conform to PEP 440 in any way then this
|
| 196 |
+
exception will be raised.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
# Validate the version and parse it into pieces
|
| 200 |
+
match = self._regex.search(version)
|
| 201 |
+
if not match:
|
| 202 |
+
raise InvalidVersion(f"Invalid version: {version!r}")
|
| 203 |
+
|
| 204 |
+
# Store the parsed out pieces of the version
|
| 205 |
+
self._version = _Version(
|
| 206 |
+
epoch=int(match.group("epoch")) if match.group("epoch") else 0,
|
| 207 |
+
release=tuple(int(i) for i in match.group("release").split(".")),
|
| 208 |
+
pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
|
| 209 |
+
post=_parse_letter_version(
|
| 210 |
+
match.group("post_l"), match.group("post_n1") or match.group("post_n2")
|
| 211 |
+
),
|
| 212 |
+
dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
|
| 213 |
+
local=_parse_local_version(match.group("local")),
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Generate a key which will be used for sorting
|
| 217 |
+
self._key = _cmpkey(
|
| 218 |
+
self._version.epoch,
|
| 219 |
+
self._version.release,
|
| 220 |
+
self._version.pre,
|
| 221 |
+
self._version.post,
|
| 222 |
+
self._version.dev,
|
| 223 |
+
self._version.local,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
def __repr__(self) -> str:
|
| 227 |
+
"""A representation of the Version that shows all internal state.
|
| 228 |
+
|
| 229 |
+
>>> Version('1.0.0')
|
| 230 |
+
<Version('1.0.0')>
|
| 231 |
+
"""
|
| 232 |
+
return f"<Version('{self}')>"
|
| 233 |
+
|
| 234 |
+
def __str__(self) -> str:
|
| 235 |
+
"""A string representation of the version that can be round-tripped.
|
| 236 |
+
|
| 237 |
+
>>> str(Version("1.0a5"))
|
| 238 |
+
'1.0a5'
|
| 239 |
+
"""
|
| 240 |
+
parts = []
|
| 241 |
+
|
| 242 |
+
# Epoch
|
| 243 |
+
if self.epoch != 0:
|
| 244 |
+
parts.append(f"{self.epoch}!")
|
| 245 |
+
|
| 246 |
+
# Release segment
|
| 247 |
+
parts.append(".".join(str(x) for x in self.release))
|
| 248 |
+
|
| 249 |
+
# Pre-release
|
| 250 |
+
if self.pre is not None:
|
| 251 |
+
parts.append("".join(str(x) for x in self.pre))
|
| 252 |
+
|
| 253 |
+
# Post-release
|
| 254 |
+
if self.post is not None:
|
| 255 |
+
parts.append(f".post{self.post}")
|
| 256 |
+
|
| 257 |
+
# Development release
|
| 258 |
+
if self.dev is not None:
|
| 259 |
+
parts.append(f".dev{self.dev}")
|
| 260 |
+
|
| 261 |
+
# Local version segment
|
| 262 |
+
if self.local is not None:
|
| 263 |
+
parts.append(f"+{self.local}")
|
| 264 |
+
|
| 265 |
+
return "".join(parts)
|
| 266 |
+
|
| 267 |
+
@property
|
| 268 |
+
def epoch(self) -> int:
|
| 269 |
+
"""The epoch of the version.
|
| 270 |
+
|
| 271 |
+
>>> Version("2.0.0").epoch
|
| 272 |
+
0
|
| 273 |
+
>>> Version("1!2.0.0").epoch
|
| 274 |
+
1
|
| 275 |
+
"""
|
| 276 |
+
return self._version.epoch
|
| 277 |
+
|
| 278 |
+
@property
|
| 279 |
+
def release(self) -> tuple[int, ...]:
|
| 280 |
+
"""The components of the "release" segment of the version.
|
| 281 |
+
|
| 282 |
+
>>> Version("1.2.3").release
|
| 283 |
+
(1, 2, 3)
|
| 284 |
+
>>> Version("2.0.0").release
|
| 285 |
+
(2, 0, 0)
|
| 286 |
+
>>> Version("1!2.0.0.post0").release
|
| 287 |
+
(2, 0, 0)
|
| 288 |
+
|
| 289 |
+
Includes trailing zeroes but not the epoch or any pre-release / development /
|
| 290 |
+
post-release suffixes.
|
| 291 |
+
"""
|
| 292 |
+
return self._version.release
|
| 293 |
+
|
| 294 |
+
@property
|
| 295 |
+
def pre(self) -> tuple[str, int] | None:
|
| 296 |
+
"""The pre-release segment of the version.
|
| 297 |
+
|
| 298 |
+
>>> print(Version("1.2.3").pre)
|
| 299 |
+
None
|
| 300 |
+
>>> Version("1.2.3a1").pre
|
| 301 |
+
('a', 1)
|
| 302 |
+
>>> Version("1.2.3b1").pre
|
| 303 |
+
('b', 1)
|
| 304 |
+
>>> Version("1.2.3rc1").pre
|
| 305 |
+
('rc', 1)
|
| 306 |
+
"""
|
| 307 |
+
return self._version.pre
|
| 308 |
+
|
| 309 |
+
@property
|
| 310 |
+
def post(self) -> int | None:
|
| 311 |
+
"""The post-release number of the version.
|
| 312 |
+
|
| 313 |
+
>>> print(Version("1.2.3").post)
|
| 314 |
+
None
|
| 315 |
+
>>> Version("1.2.3.post1").post
|
| 316 |
+
1
|
| 317 |
+
"""
|
| 318 |
+
return self._version.post[1] if self._version.post else None
|
| 319 |
+
|
| 320 |
+
@property
|
| 321 |
+
def dev(self) -> int | None:
|
| 322 |
+
"""The development number of the version.
|
| 323 |
+
|
| 324 |
+
>>> print(Version("1.2.3").dev)
|
| 325 |
+
None
|
| 326 |
+
>>> Version("1.2.3.dev1").dev
|
| 327 |
+
1
|
| 328 |
+
"""
|
| 329 |
+
return self._version.dev[1] if self._version.dev else None
|
| 330 |
+
|
| 331 |
+
@property
|
| 332 |
+
def local(self) -> str | None:
|
| 333 |
+
"""The local version segment of the version.
|
| 334 |
+
|
| 335 |
+
>>> print(Version("1.2.3").local)
|
| 336 |
+
None
|
| 337 |
+
>>> Version("1.2.3+abc").local
|
| 338 |
+
'abc'
|
| 339 |
+
"""
|
| 340 |
+
if self._version.local:
|
| 341 |
+
return ".".join(str(x) for x in self._version.local)
|
| 342 |
+
else:
|
| 343 |
+
return None
|
| 344 |
+
|
| 345 |
+
@property
|
| 346 |
+
def public(self) -> str:
|
| 347 |
+
"""The public portion of the version.
|
| 348 |
+
|
| 349 |
+
>>> Version("1.2.3").public
|
| 350 |
+
'1.2.3'
|
| 351 |
+
>>> Version("1.2.3+abc").public
|
| 352 |
+
'1.2.3'
|
| 353 |
+
>>> Version("1!1.2.3dev1+abc").public
|
| 354 |
+
'1!1.2.3.dev1'
|
| 355 |
+
"""
|
| 356 |
+
return str(self).split("+", 1)[0]
|
| 357 |
+
|
| 358 |
+
@property
|
| 359 |
+
def base_version(self) -> str:
|
| 360 |
+
"""The "base version" of the version.
|
| 361 |
+
|
| 362 |
+
>>> Version("1.2.3").base_version
|
| 363 |
+
'1.2.3'
|
| 364 |
+
>>> Version("1.2.3+abc").base_version
|
| 365 |
+
'1.2.3'
|
| 366 |
+
>>> Version("1!1.2.3dev1+abc").base_version
|
| 367 |
+
'1!1.2.3'
|
| 368 |
+
|
| 369 |
+
The "base version" is the public version of the project without any pre or post
|
| 370 |
+
release markers.
|
| 371 |
+
"""
|
| 372 |
+
parts = []
|
| 373 |
+
|
| 374 |
+
# Epoch
|
| 375 |
+
if self.epoch != 0:
|
| 376 |
+
parts.append(f"{self.epoch}!")
|
| 377 |
+
|
| 378 |
+
# Release segment
|
| 379 |
+
parts.append(".".join(str(x) for x in self.release))
|
| 380 |
+
|
| 381 |
+
return "".join(parts)
|
| 382 |
+
|
| 383 |
+
@property
|
| 384 |
+
def is_prerelease(self) -> bool:
|
| 385 |
+
"""Whether this version is a pre-release.
|
| 386 |
+
|
| 387 |
+
>>> Version("1.2.3").is_prerelease
|
| 388 |
+
False
|
| 389 |
+
>>> Version("1.2.3a1").is_prerelease
|
| 390 |
+
True
|
| 391 |
+
>>> Version("1.2.3b1").is_prerelease
|
| 392 |
+
True
|
| 393 |
+
>>> Version("1.2.3rc1").is_prerelease
|
| 394 |
+
True
|
| 395 |
+
>>> Version("1.2.3dev1").is_prerelease
|
| 396 |
+
True
|
| 397 |
+
"""
|
| 398 |
+
return self.dev is not None or self.pre is not None
|
| 399 |
+
|
| 400 |
+
@property
|
| 401 |
+
def is_postrelease(self) -> bool:
|
| 402 |
+
"""Whether this version is a post-release.
|
| 403 |
+
|
| 404 |
+
>>> Version("1.2.3").is_postrelease
|
| 405 |
+
False
|
| 406 |
+
>>> Version("1.2.3.post1").is_postrelease
|
| 407 |
+
True
|
| 408 |
+
"""
|
| 409 |
+
return self.post is not None
|
| 410 |
+
|
| 411 |
+
@property
|
| 412 |
+
def is_devrelease(self) -> bool:
|
| 413 |
+
"""Whether this version is a development release.
|
| 414 |
+
|
| 415 |
+
>>> Version("1.2.3").is_devrelease
|
| 416 |
+
False
|
| 417 |
+
>>> Version("1.2.3.dev1").is_devrelease
|
| 418 |
+
True
|
| 419 |
+
"""
|
| 420 |
+
return self.dev is not None
|
| 421 |
+
|
| 422 |
+
@property
|
| 423 |
+
def major(self) -> int:
|
| 424 |
+
"""The first item of :attr:`release` or ``0`` if unavailable.
|
| 425 |
+
|
| 426 |
+
>>> Version("1.2.3").major
|
| 427 |
+
1
|
| 428 |
+
"""
|
| 429 |
+
return self.release[0] if len(self.release) >= 1 else 0
|
| 430 |
+
|
| 431 |
+
@property
|
| 432 |
+
def minor(self) -> int:
|
| 433 |
+
"""The second item of :attr:`release` or ``0`` if unavailable.
|
| 434 |
+
|
| 435 |
+
>>> Version("1.2.3").minor
|
| 436 |
+
2
|
| 437 |
+
>>> Version("1").minor
|
| 438 |
+
0
|
| 439 |
+
"""
|
| 440 |
+
return self.release[1] if len(self.release) >= 2 else 0
|
| 441 |
+
|
| 442 |
+
@property
|
| 443 |
+
def micro(self) -> int:
|
| 444 |
+
"""The third item of :attr:`release` or ``0`` if unavailable.
|
| 445 |
+
|
| 446 |
+
>>> Version("1.2.3").micro
|
| 447 |
+
3
|
| 448 |
+
>>> Version("1").micro
|
| 449 |
+
0
|
| 450 |
+
"""
|
| 451 |
+
return self.release[2] if len(self.release) >= 3 else 0
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
class _TrimmedRelease(Version):
|
| 455 |
+
@property
|
| 456 |
+
def release(self) -> tuple[int, ...]:
|
| 457 |
+
"""
|
| 458 |
+
Release segment without any trailing zeros.
|
| 459 |
+
|
| 460 |
+
>>> _TrimmedRelease('1.0.0').release
|
| 461 |
+
(1,)
|
| 462 |
+
>>> _TrimmedRelease('0.0').release
|
| 463 |
+
(0,)
|
| 464 |
+
"""
|
| 465 |
+
rel = super().release
|
| 466 |
+
nonzeros = (index for index, val in enumerate(rel) if val)
|
| 467 |
+
last_nonzero = max(nonzeros, default=0)
|
| 468 |
+
return rel[: last_nonzero + 1]
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _parse_letter_version(
|
| 472 |
+
letter: str | None, number: str | bytes | SupportsInt | None
|
| 473 |
+
) -> tuple[str, int] | None:
|
| 474 |
+
if letter:
|
| 475 |
+
# We consider there to be an implicit 0 in a pre-release if there is
|
| 476 |
+
# not a numeral associated with it.
|
| 477 |
+
if number is None:
|
| 478 |
+
number = 0
|
| 479 |
+
|
| 480 |
+
# We normalize any letters to their lower case form
|
| 481 |
+
letter = letter.lower()
|
| 482 |
+
|
| 483 |
+
# We consider some words to be alternate spellings of other words and
|
| 484 |
+
# in those cases we want to normalize the spellings to our preferred
|
| 485 |
+
# spelling.
|
| 486 |
+
if letter == "alpha":
|
| 487 |
+
letter = "a"
|
| 488 |
+
elif letter == "beta":
|
| 489 |
+
letter = "b"
|
| 490 |
+
elif letter in ["c", "pre", "preview"]:
|
| 491 |
+
letter = "rc"
|
| 492 |
+
elif letter in ["rev", "r"]:
|
| 493 |
+
letter = "post"
|
| 494 |
+
|
| 495 |
+
return letter, int(number)
|
| 496 |
+
|
| 497 |
+
assert not letter
|
| 498 |
+
if number:
|
| 499 |
+
# We assume if we are given a number, but we are not given a letter
|
| 500 |
+
# then this is using the implicit post release syntax (e.g. 1.0-1)
|
| 501 |
+
letter = "post"
|
| 502 |
+
|
| 503 |
+
return letter, int(number)
|
| 504 |
+
|
| 505 |
+
return None
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
_local_version_separators = re.compile(r"[\._-]")
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def _parse_local_version(local: str | None) -> LocalType | None:
|
| 512 |
+
"""
|
| 513 |
+
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
|
| 514 |
+
"""
|
| 515 |
+
if local is not None:
|
| 516 |
+
return tuple(
|
| 517 |
+
part.lower() if not part.isdigit() else int(part)
|
| 518 |
+
for part in _local_version_separators.split(local)
|
| 519 |
+
)
|
| 520 |
+
return None
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def _cmpkey(
|
| 524 |
+
epoch: int,
|
| 525 |
+
release: tuple[int, ...],
|
| 526 |
+
pre: tuple[str, int] | None,
|
| 527 |
+
post: tuple[str, int] | None,
|
| 528 |
+
dev: tuple[str, int] | None,
|
| 529 |
+
local: LocalType | None,
|
| 530 |
+
) -> CmpKey:
|
| 531 |
+
# When we compare a release version, we want to compare it with all of the
|
| 532 |
+
# trailing zeros removed. So we'll use a reverse the list, drop all the now
|
| 533 |
+
# leading zeros until we come to something non zero, then take the rest
|
| 534 |
+
# re-reverse it back into the correct order and make it a tuple and use
|
| 535 |
+
# that for our sorting key.
|
| 536 |
+
_release = tuple(
|
| 537 |
+
reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
|
| 541 |
+
# We'll do this by abusing the pre segment, but we _only_ want to do this
|
| 542 |
+
# if there is not a pre or a post segment. If we have one of those then
|
| 543 |
+
# the normal sorting rules will handle this case correctly.
|
| 544 |
+
if pre is None and post is None and dev is not None:
|
| 545 |
+
_pre: CmpPrePostDevType = NegativeInfinity
|
| 546 |
+
# Versions without a pre-release (except as noted above) should sort after
|
| 547 |
+
# those with one.
|
| 548 |
+
elif pre is None:
|
| 549 |
+
_pre = Infinity
|
| 550 |
+
else:
|
| 551 |
+
_pre = pre
|
| 552 |
+
|
| 553 |
+
# Versions without a post segment should sort before those with one.
|
| 554 |
+
if post is None:
|
| 555 |
+
_post: CmpPrePostDevType = NegativeInfinity
|
| 556 |
+
|
| 557 |
+
else:
|
| 558 |
+
_post = post
|
| 559 |
+
|
| 560 |
+
# Versions without a development segment should sort after those with one.
|
| 561 |
+
if dev is None:
|
| 562 |
+
_dev: CmpPrePostDevType = Infinity
|
| 563 |
+
|
| 564 |
+
else:
|
| 565 |
+
_dev = dev
|
| 566 |
+
|
| 567 |
+
if local is None:
|
| 568 |
+
# Versions without a local segment should sort before those with one.
|
| 569 |
+
_local: CmpLocalType = NegativeInfinity
|
| 570 |
+
else:
|
| 571 |
+
# Versions with a local segment need that segment parsed to implement
|
| 572 |
+
# the sorting rules in PEP440.
|
| 573 |
+
# - Alpha numeric segments sort before numeric segments
|
| 574 |
+
# - Alpha numeric segments sort lexicographically
|
| 575 |
+
# - Numeric segments sort numerically
|
| 576 |
+
# - Shorter versions sort before longer versions when the prefixes
|
| 577 |
+
# match exactly
|
| 578 |
+
_local = tuple(
|
| 579 |
+
(i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
return epoch, _release, _pre, _post, _dev, _local
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/cachecontrol/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: 2015 Eric Larson
|
| 2 |
+
#
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
|
| 5 |
+
"""CacheControl import Interface.
|
| 6 |
+
|
| 7 |
+
Make it easy to import from cachecontrol without long namespaces.
|
| 8 |
+
"""
|
| 9 |
+
__author__ = "Eric Larson"
|
| 10 |
+
__email__ = "eric@ionrock.org"
|
| 11 |
+
__version__ = "0.14.0"
|
| 12 |
+
|
| 13 |
+
from pip._vendor.cachecontrol.adapter import CacheControlAdapter
|
| 14 |
+
from pip._vendor.cachecontrol.controller import CacheController
|
| 15 |
+
from pip._vendor.cachecontrol.wrapper import CacheControl
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"__author__",
|
| 19 |
+
"__email__",
|
| 20 |
+
"__version__",
|
| 21 |
+
"CacheControlAdapter",
|
| 22 |
+
"CacheController",
|
| 23 |
+
"CacheControl",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
import logging
|
| 27 |
+
|
| 28 |
+
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/cachecontrol/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (999 Bytes). View file
|
|
|