# decorator
import inspect
from functools import wraps
from numbers import Number

import torch

def get_device(args, kwargs):
    """
    Return the common device of all tensor arguments, or None if no argument
    is a tensor. Raises if the tensors live on different devices.
    """
    device = None
    for arg in (list(args) + list(kwargs.values())):
        if isinstance(arg, torch.Tensor):
            if device is None:
                device = arg.device
            elif device != arg.device:
                raise ValueError("All tensors must be on the same device.")
    return device
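
# A small illustration of get_device (example values assumed, not from the source):
#
#   >>> get_device((torch.zeros(2),), {'x': torch.ones(2)})
#   device(type='cpu')
#   >>> get_device((1.0, [2.0]), {})   # no tensors present -> returns None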

def get_args_order(func, args, kwargs):
    """
    Map each positional and keyword argument to the index of the parameter it
    binds to in the signature of `func`.
    """
    names = inspect.getfullargspec(func).args
    names_idx = {name: i for i, name in enumerate(names)}
    args_order = []
    kwargs_order = {}
    for name in kwargs:
        if name in names:
            kwargs_order[name] = names_idx[name]
            names.remove(name)
    for i, arg in enumerate(args):
        if i < len(names):
            args_order.append(names_idx[names[i]])
    return args_order, kwargs_order
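
# A sketch of the resulting mapping (example function assumed, not from the source):
#
#   >>> def f(a, b, c): ...
#   >>> get_args_order(f, (1,), {'c': 3})
#   ([0], {'c': 2})
#
# The positional argument binds to parameter 'a' (index 0) once the keyword
# argument 'c' (index 2) has been claimed.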

def broadcast_args(args, kwargs, args_dim, kwargs_dim):
    """
    Broadcast the leading (batch) dimensions of all tensor arguments to a
    common batch shape. `args_dim` / `kwargs_dim` give, per argument, the
    number of trailing dimensions that are *not* part of the batch.
    """
    spatial = []
    for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
        if isinstance(arg, torch.Tensor) and arg_dim is not None:
            arg_spatial = arg.shape[:arg.ndim - arg_dim]
            if len(arg_spatial) > len(spatial):
                spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
            # Compare batch shapes right-to-left; the 1-based offset avoids the
            # `-0` indexing bug that would skip the last dimension.
            for j in range(1, len(arg_spatial) + 1):
                if spatial[-j] < arg_spatial[-j]:
                    if spatial[-j] == 1:
                        spatial[-j] = arg_spatial[-j]
                    else:
                        raise ValueError("Cannot broadcast arguments.")
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor) and args_dim[i] is not None:
            args[i] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim - args_dim[i]:]])
    for key, arg in kwargs.items():
        if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
            kwargs[key] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim - kwargs_dim[key]:]])
    return args, kwargs, spatial
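
# Shape sketch (assumed example): with one non-batch dimension per tensor
# (arg_dim == 1), a (5, 3) tensor and a (2, 5, 3) tensor yield the common
# batch shape spatial == [2, 5], and the first tensor is broadcast to
# (2, 5, 3) before the wrapper below flattens both to (10, 3).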

def batched(*dims):
    """
    Decorator that allows a function written for unbatched inputs to be called
    with arbitrarily batched arguments. `dims` gives, for each parameter in
    order, the number of trailing (non-batch) dimensions it expects, or None
    to leave that argument untouched.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, device=torch.device('cpu'), **kwargs):
            args = list(args)
            # get each argument's expected number of non-batch dimensions
            args_order, kwargs_order = get_args_order(func, args, kwargs)
            args_dim = [dims[i] for i in args_order]
            kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
            # convert plain numbers and sequences to torch tensors
            device = get_device(args, kwargs) or device
            for i, arg in enumerate(args):
                if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
                    args[i] = torch.tensor(arg, device=device)
            for key, arg in kwargs.items():
                if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
                    kwargs[key] = torch.tensor(arg, device=device)
            # broadcast batch dimensions, then flatten them into a single one
            args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
            for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
                if isinstance(arg, torch.Tensor) and arg_dim is not None:
                    args[i] = arg.reshape([-1, *arg.shape[arg.ndim - arg_dim:]])
            for key, arg in kwargs.items():
                if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
                    kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim - kwargs_dim[key]:]])
            # call the function on the flattened batch
            results = func(*args, **kwargs)
            type_results = type(results)
            results = list(results) if isinstance(results, (tuple, list)) else [results]
            # restore the original batch (spatial) dimensions
            for i, result in enumerate(results):
                results[i] = result.reshape([*spatial, *result.shape[1:]])
            if type_results == tuple:
                results = tuple(results)
            elif type_results != list:
                results = results[0]
            return results
        return wrapper
    return decorator
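
# A usage sketch (hypothetical example, not part of the module): `dot` is
# written for single vectors; @batched(1, 1) declares one non-batch dimension
# per input, so the wrapped function transparently accepts any broadcastable
# batch shapes.
if __name__ == '__main__':
    @batched(1, 1)
    def dot(a, b):
        # inside the wrapper, both inputs arrive flattened to shape (N, 3)
        return (a * b).sum(dim=-1)

    a = torch.randn(2, 5, 3)
    b = torch.randn(5, 3)       # broadcast against a's batch shape (2, 5)
    print(dot(a, b).shape)      # torch.Size([2, 5])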