File size: 308 Bytes
801501a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import numpy as np
import torch

def np2th(ndarray):
    """Convert a NumPy array or torch tensor to a CPU torch tensor.

    Args:
        ndarray: A ``np.ndarray`` or a ``torch.Tensor``.

    Returns:
        torch.Tensor: For a ``torch.Tensor`` input, the tensor detached from
        the autograd graph and moved to CPU. Its dtype is left unchanged. If
        the input was already on CPU, the result shares storage with it. For
        an ``np.ndarray`` input, a new ``float32`` tensor holding a copy of
        the data.

    Raises:
        ValueError: If ``ndarray`` is neither a ``torch.Tensor`` nor an
            ``np.ndarray``.
    """
    if isinstance(ndarray, torch.Tensor):
        return ndarray.detach().cpu()
    if isinstance(ndarray, np.ndarray):
        # Cast on the NumPy side. ``np.array`` always copies, so this makes a
        # single native-byte-order, writable float32 buffer that
        # ``torch.from_numpy`` can wrap without another copy. The previous
        # ``torch.tensor(ndarray).float()`` copied twice for non-float32
        # input and rejected arrays torch cannot wrap directly, such as
        # non-native byte order (e.g. dtype '>f8').
        return torch.from_numpy(np.array(ndarray, dtype=np.float32))
    raise ValueError("Input should be either torch.Tensor or np.ndarray")