Spaces:
Sleeping
Sleeping
| # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved. | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| import ctypes | |
| import numpy | |
| import warp | |
def device_from_torch(torch_device):
    """Return the Warp device corresponding to a torch device.

    The torch device is converted to its string form, which is then
    resolved through ``warp.get_device``.
    """
    device_alias = str(torch_device)
    return warp.get_device(device_alias)
def device_to_torch(wp_device):
    """Return the torch device identifier corresponding to a Warp device.

    Args:
        wp_device: Any value that ``warp.get_device`` can resolve.

    Returns:
        str: ``str(device)`` for CPU devices and primary CUDA contexts,
        or ``"cuda:<ordinal>"`` for other CUDA contexts that support UVA.

    Raises:
        RuntimeError: If the device matches neither of the cases above.
    """
    resolved = warp.get_device(wp_device)

    # CPU devices and primary contexts are passed to torch by their string form
    if resolved.is_cpu or resolved.is_primary:
        return str(resolved)

    # not a primary context, but torch can access the data ptr directly thanks to UVA
    if resolved.is_cuda and resolved.is_uva:
        return f"cuda:{resolved.ordinal}"

    raise RuntimeError(f"Warp device {resolved} is not compatible with torch")
def dtype_from_torch(torch_dtype):
    """Return the Warp dtype corresponding to a torch dtype.

    Raises:
        TypeError: If the torch dtype has no entry in the lookup table.
    """
    lookup = dtype_from_torch.type_map

    # build the lookup table on the first call so the torch import is deferred
    if lookup is None:
        import torch

        # every supported dtype carries the same attribute name in torch and in Warp
        shared_names = (
            "float64",
            "float32",
            "float16",
            "int64",
            "int32",
            "int16",
            "int8",
            "uint8",
            "bool",
            # currently unsupported by Warp:
            #   bfloat16, complex64, complex128
        )
        lookup = {getattr(torch, name): getattr(warp, name) for name in shared_names}
        dtype_from_torch.type_map = lookup

    mapped = lookup.get(torch_dtype)
    if mapped is None:
        raise TypeError(f"Invalid or unsupported data type: {torch_dtype}")
    return mapped


# lookup table cache, filled in by the first call to dtype_from_torch()
dtype_from_torch.type_map = None
def dtype_is_compatible(torch_dtype, warp_dtype):
    """Evaluate whether the given torch dtype is compatible with the given Warp dtype.

    A Warp type that defines ``_wp_scalar_type_`` is checked through that
    scalar type; any other type is checked directly.

    Returns:
        bool: True if the Warp dtype is in the set of types accepted for
        ``torch_dtype``.

    Raises:
        TypeError: If the torch dtype has no entry in the compatibility table.
    """
    table = dtype_is_compatible.compatible_sets

    # build the compatibility table on the first call so the torch import is deferred
    if table is None:
        import torch

        table = {
            torch.float64: {warp.float64},
            torch.float32: {warp.float32},
            torch.float16: {warp.float16},
            # allow aliasing integer tensors as signed or unsigned integer arrays
            torch.int64: {warp.int64, warp.uint64},
            torch.int32: {warp.int32, warp.uint32},
            torch.int16: {warp.int16, warp.uint16},
            torch.int8: {warp.int8, warp.uint8},
            torch.uint8: {warp.uint8, warp.int8},
            torch.bool: {warp.bool, warp.uint8, warp.int8},
            # currently unsupported by Warp
            # torch.bfloat16:
            # torch.complex64:
            # torch.complex128:
        }
        dtype_is_compatible.compatible_sets = table

    accepted = table.get(torch_dtype)
    if accepted is None:
        raise TypeError(f"Invalid or unsupported data type: {torch_dtype}")

    # fall back to the type itself when it does not declare a scalar type
    scalar_type = getattr(warp_dtype, "_wp_scalar_type_", warp_dtype)
    return scalar_type in accepted


# compatibility table cache, filled in by the first call to dtype_is_compatible()
dtype_is_compatible.compatible_sets = None
def from_torch(t, dtype=None, requires_grad=None, grad=None):
    """Wrap a PyTorch tensor to a Warp array without copying the data.

    Args:
        t (torch.Tensor): The torch tensor to wrap.
        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.
        grad (warp.array or torch.Tensor, optional): The gradient to attach to the resulting array. A torch tensor is wrapped (not copied) using the same ``dtype``. When omitted and ``requires_grad`` is true, the tensor's own ``grad`` is wrapped instead.

    Returns:
        warp.array: The wrapped array.

    Raises:
        RuntimeError: If ``dtype`` is not compatible with the tensor's dtype, or if the
            tensor's trailing dimensions or strides do not match a vector/matrix ``dtype``.
        TypeError: If the tensor's dtype is not supported.
        ValueError: If ``grad`` is neither a Warp array nor a torch tensor.
    """
    if dtype is None:
        dtype = dtype_from_torch(t.dtype)
    elif not dtype_is_compatible(t.dtype, dtype):
        raise RuntimeError(f"Incompatible data types: {t.dtype} and {dtype}")

    # get size of underlying data type to compute strides
    ctype_size = ctypes.sizeof(dtype._type_)

    shape = tuple(t.shape)
    # scale the tensor's element strides by the element size
    strides = tuple(s * ctype_size for s in t.stride())

    # if target is a vector or matrix type
    # then check if trailing dimensions match
    # the target type and update the shape
    if hasattr(dtype, "_shape_"):
        dtype_shape = dtype._shape_
        dtype_dims = len(dtype_shape)
        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
            )

        # ensure the inner strides are contiguous, walking from the innermost dimension outwards
        expected_stride = ctype_size
        for inner_stride, inner_size in zip(reversed(strides), reversed(dtype_shape)):
            if inner_stride != expected_stride:
                raise RuntimeError(
                    f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
                )
            expected_stride *= inner_size

        # the trailing dimensions are absorbed into the dtype; fall back to a
        # single element when no outer dimensions remain
        shape = tuple(shape[:-dtype_dims]) or (1,)
        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)

    requires_grad = t.requires_grad if requires_grad is None else requires_grad

    if grad is not None:
        if not isinstance(grad, warp.array):
            import torch

            if isinstance(grad, torch.Tensor):
                # the gradient array itself must not track (and allocate) a gradient
                grad = from_torch(grad, dtype=dtype, requires_grad=False)
            else:
                raise ValueError(f"Invalid gradient type: {type(grad)}")
    elif requires_grad:
        # wrap the tensor gradient, allocate if necessary
        if t.grad is None:
            # allocate a zero-filled gradient tensor if it doesn't exist
            import torch

            t.grad = torch.zeros_like(t, requires_grad=False)
        # the gradient array itself must not track (and allocate) a gradient
        grad = from_torch(t.grad, dtype=dtype, requires_grad=False)

    a = warp.types.array(
        ptr=t.data_ptr(),
        dtype=dtype,
        shape=shape,
        strides=strides,
        device=device_from_torch(t.device),
        copy=False,
        owner=False,
        grad=grad,
        requires_grad=requires_grad,
    )

    # save a reference to the source tensor, otherwise it will be deallocated
    a._tensor = t
    return a
def to_torch(a, requires_grad=None):
    """Convert a Warp array to a PyTorch tensor without copying the data.

    Args:
        a (warp.array): The Warp array to convert.
        requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value.

    Returns:
        torch.Tensor: The converted tensor.

    Raises:
        RuntimeError: If the array holds a struct dtype, or lives on a device
            that is neither CPU nor CUDA.
    """
    import torch

    if requires_grad is None:
        requires_grad = a.requires_grad

    # Torch does not support structured arrays
    if isinstance(a.dtype, warp.codegen.Struct):
        raise RuntimeError("Cannot convert structured Warp arrays to Torch.")

    source_device = a.device

    if source_device.is_cpu:
        # Torch has an issue wrapping CPU objects that support the
        # __array_interface__ protocol, so go through an ndarray first,
        # see https://pearu.github.io/array_interface_pytorch.html
        def as_tensor(arr):
            return torch.as_tensor(numpy.asarray(arr))

    elif source_device.is_cuda:
        # Torch does support the __cuda_array_interface__ correctly,
        # so the array can be handed over directly
        target_device = device_to_torch(source_device)

        def as_tensor(arr):
            return torch.as_tensor(arr, device=target_device)

    else:
        raise RuntimeError("Unsupported device")

    result = as_tensor(a)
    result.requires_grad = requires_grad

    # only expose a gradient when both the caller and the array ask for one
    if requires_grad and a.requires_grad:
        result.grad = as_tensor(a.grad)

    return result
def stream_from_torch(stream_or_device=None):
    """Convert from a PyTorch CUDA stream to a Warp.Stream.

    Args:
        stream_or_device: A ``torch.cuda.Stream``, or a value passed to
            ``torch.cuda.current_stream`` to look up the stream to convert.

    Returns:
        warp.Stream: A Warp stream created from the torch stream's ``cuda_stream`` handle.
    """
    import torch

    # anything that is not already a stream is assumed to be a torch device
    torch_stream = (
        stream_or_device
        if isinstance(stream_or_device, torch.cuda.Stream)
        else torch.cuda.current_stream(stream_or_device)
    )

    wrapped = warp.Stream(
        device_from_torch(torch_stream.device),
        cuda_stream=torch_stream.cuda_stream,
    )

    # save a reference to the source stream, otherwise it may be destroyed
    wrapped._torch_stream = torch_stream
    return wrapped
def stream_to_torch(stream_or_device=None):
    """Convert from a Warp.Stream to a PyTorch CUDA stream.

    Args:
        stream_or_device: A ``warp.Stream``, or a value passed to
            ``warp.get_device`` whose ``stream`` attribute is converted.

    Returns:
        torch.cuda.ExternalStream: A torch stream created from the Warp stream's ``cuda_stream`` handle.
    """
    import torch

    if not isinstance(stream_or_device, warp.Stream):
        # assume arg is a warp device
        wp_stream = warp.get_device(stream_or_device).stream
    else:
        wp_stream = stream_or_device

    target_device = device_to_torch(wp_stream.device)
    external = torch.cuda.ExternalStream(wp_stream.cuda_stream, device=target_device)

    # save a reference to the source stream, otherwise it may be destroyed
    external._warp_stream = wp_stream
    return external