|
import torch
|
|
|
|
|
|
def to_numpy(tensor):
    """Convert a torch tensor to a numpy array; pass numpy objects through.

    Args:
        tensor: A ``torch.Tensor`` or an object whose type is defined in the
            ``numpy`` module (e.g. ``numpy.ndarray``).

    Returns:
        For a torch tensor, the numpy array produced from its detached CPU
        copy/view. For a numpy object, the very same object, unchanged.

    Raises:
        ValueError: If ``tensor`` is neither a torch tensor nor a numpy object.
    """
    if torch.is_tensor(tensor):
        # detach() first: Tensor.numpy() refuses tensors that require grad,
        # so without it any tensor still attached to the autograd graph
        # would raise a RuntimeError here.
        return tensor.detach().cpu().numpy()
    if type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array".format(
            type(tensor)))
    return tensor
|
|
|
|
|
|
def to_torch(ndarray):
    """Convert a numpy object to a torch tensor; pass tensors through.

    Args:
        ndarray: An object whose type is defined in the ``numpy`` module, or
            an existing torch tensor.

    Returns:
        ``torch.from_numpy(ndarray)`` for numpy inputs, or the input itself
        when it is already a torch tensor.

    Raises:
        ValueError: If the input is neither a numpy object nor a torch tensor.
    """
    origin_module = type(ndarray).__module__
    if origin_module == 'numpy':
        return torch.from_numpy(ndarray)

    if torch.is_tensor(ndarray):
        # Already a tensor: hand back the same object untouched.
        return ndarray

    message = "Cannot convert {} to torch tensor".format(type(ndarray))
    raise ValueError(message)
|
|
|
|
|
|
def cleanexit():
    """Terminate the current process immediately with exit status 0.

    ``sys.exit(0)`` only raises ``SystemExit``; that exception is caught
    right here and the process is then ended with ``os._exit(0)``, which
    does not run the interpreter's normal shutdown steps. This function
    never returns.
    """
    import os
    import sys

    try:
        sys.exit(0)
    except SystemExit:
        # os._exit() does not flush stdio buffers, so anything still buffered
        # (typical when output is piped or redirected to a file) would be
        # lost. Flush explicitly before the hard exit.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.flush()
            except (AttributeError, OSError, ValueError):
                # Best effort only: the stream may be None, closed, or
                # broken, and none of that may prevent the exit below.
                pass
        os._exit(0)
|
|
|
|
def load_model_wo_clip(model, state_dict):
    """Load ``state_dict`` into ``model`` with ``strict=False``.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``
            whose result unpacks into two items.
        state_dict: The state to load into ``model``.

    NOTE(review): the name suggests the non-strict load exists so that
    checkpoints lacking CLIP weights can be loaded -- confirm with callers.
    The key lists returned by the load are not inspected here.
    """
    load_result = model.load_state_dict(state_dict, strict=False)
    # Unpack into (missing, unexpected) key lists; neither is used further.
    missing_keys, unexpected_keys = load_result
|
|
|
|
|
|
|