import copy
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
def detach_to_numpy(
    data: Union[List, Tuple, Dict, torch.Tensor, np.ndarray],
) -> Union[List, Tuple, Dict, np.ndarray]:
    """
    Recursively convert every tensor in ``data`` to a NumPy array.

    Tensors are detached from the autograd graph and moved to CPU before
    conversion. NumPy arrays are returned unchanged. Lists, tuples (including
    namedtuples) and dicts are traversed recursively; a new container of the
    same kind is returned and the input container is left unmodified.

    Args:
        data: A tensor, a NumPy array, or an arbitrarily nested list / tuple /
            dict whose leaves are tensors or NumPy arrays.

    Returns:
        ``data`` with every tensor leaf replaced by a NumPy array.

    Raises:
        ValueError: If ``data`` (or any nested leaf) has an unsupported type.
    """
    if isinstance(data, torch.Tensor):
        # Detach first so the device-to-CPU copy is not recorded by autograd.
        return data.detach().cpu().numpy()  # pytype: disable=attribute-error
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, list):
        return [detach_to_numpy(d) for d in data]
    if isinstance(data, tuple):
        items = [detach_to_numpy(d) for d in data]
        # namedtuples take their fields positionally; plain tuples take one iterable.
        if hasattr(data, "_fields"):
            return type(data)(*items)
        return tuple(items)
    if isinstance(data, dict):
        # Shallow-copy so the caller's dict is not mutated; copy.copy keeps the
        # concrete mapping type (e.g. OrderedDict, defaultdict) and key order.
        out = copy.copy(data)
        for k, v in data.items():
            out[k] = detach_to_numpy(v)
        return out
    raise ValueError(
        "data should be tensor, numpy array, dict, list, or tuple; "
        f"got {type(data).__name__}."
    )