PeterYu's picture
update
2875fe6
import os
import numpy as np
import pickle
import torch
import random
from datetime import datetime
def pkl_save(name, var):
with open(name, "wb") as f:
pickle.dump(var, f)
def pkl_load(name):
with open(name, "rb") as f:
return pickle.load(f)
def torch_pad_nan(arr, left=0, right=0, dim=0):
if left > 0:
padshape = list(arr.shape)
padshape[dim] = left
arr = torch.cat((torch.full(padshape, np.nan), arr), dim=dim)
if right > 0:
padshape = list(arr.shape)
padshape[dim] = right
arr = torch.cat((arr, torch.full(padshape, np.nan)), dim=dim)
return arr
def pad_nan_to_target(array, target_length, axis=0, both_side=False):
assert array.dtype in [np.float16, np.float32, np.float64]
pad_size = target_length - array.shape[axis]
if pad_size <= 0:
return array
npad = [(0, 0)] * array.ndim
if both_side:
npad[axis] = (pad_size // 2, pad_size - pad_size // 2)
else:
npad[axis] = (0, pad_size)
return np.pad(array, pad_width=npad, mode="constant", constant_values=np.nan)
def split_with_nan(x, sections, axis=0):
assert x.dtype in [np.float16, np.float32, np.float64]
arrs = np.array_split(x, sections, axis=axis)
target_length = arrs[0].shape[axis]
for i in range(len(arrs)):
arrs[i] = pad_nan_to_target(arrs[i], target_length, axis=axis)
return arrs
def take_per_row(A, indx, num_elem):
all_indx = indx[:, None] + np.arange(num_elem)
return A[torch.arange(all_indx.shape[0])[:, None], all_indx]
def centerize_vary_length_series(x):
prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1)
suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1)
offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros
rows, column_indices = np.ogrid[: x.shape[0], : x.shape[1]]
offset[offset < 0] += x.shape[1]
column_indices = column_indices - offset[:, np.newaxis]
return x[rows, column_indices]
def data_dropout(arr, p):
B, T = arr.shape[0], arr.shape[1]
mask = np.full(B * T, False, dtype=np.bool)
ele_sel = np.random.choice(B * T, size=int(B * T * p), replace=False)
mask[ele_sel] = True
res = arr.copy()
res[mask.reshape(B, T)] = np.nan
return res
def name_with_datetime(prefix="default"):
now = datetime.now()
return prefix + "_" + now.strftime("%Y%m%d_%H%M%S")
def init_dl_program(
device_name,
seed=None,
use_cudnn=True,
deterministic=False,
benchmark=False,
use_tf32=False,
max_threads=None,
):
import torch
if max_threads is not None:
torch.set_num_threads(max_threads) # intraop
if torch.get_num_interop_threads() != max_threads:
torch.set_num_interop_threads(max_threads) # interop
try:
import mkl
except:
pass
else:
mkl.set_num_threads(max_threads)
if seed is not None:
random.seed(seed)
seed += 1
np.random.seed(seed)
seed += 1
torch.manual_seed(seed)
if isinstance(device_name, (str, int)):
device_name = [device_name]
devices = []
for t in reversed(device_name):
t_device = torch.device(t)
devices.append(t_device)
if t_device.type == "cuda":
assert torch.cuda.is_available()
torch.cuda.set_device(t_device)
if seed is not None:
seed += 1
torch.cuda.manual_seed(seed)
devices.reverse()
torch.backends.cudnn.enabled = use_cudnn
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = benchmark
if hasattr(torch.backends.cudnn, "allow_tf32"):
torch.backends.cudnn.allow_tf32 = use_tf32
torch.backends.cuda.matmul.allow_tf32 = use_tf32
return devices if len(devices) > 1 else devices[0]