salad-demo / salad /utils /nputil.py
DveloperY0115's picture
init repo
801501a
raw
history blame
308 Bytes
import numpy as np
import torch
def np2th(ndarray):
    """Convert ``ndarray`` into a CPU ``torch.Tensor``.

    A ``torch.Tensor`` input is detached from the autograd graph and moved to
    the CPU; its dtype is left as-is. A ``np.ndarray`` input is copied into a
    new tensor (``torch.tensor`` always copies) and then cast with
    ``Tensor.float()``.

    Args:
        ndarray: A ``np.ndarray`` or a ``torch.Tensor``.

    Returns:
        A ``torch.Tensor``.

    Raises:
        ValueError: If ``ndarray`` is neither a ``torch.Tensor`` nor a
            ``np.ndarray``.
    """
    if isinstance(ndarray, np.ndarray):
        copied = torch.tensor(ndarray)
        return copied.float()
    if isinstance(ndarray, torch.Tensor):
        # Drop autograd history and ensure host memory; dtype is untouched.
        return ndarray.detach().cpu()
    raise ValueError("Input should be either torch.Tensor or np.ndarray")