Spaces:
Runtime error
Runtime error
import os | |
import re | |
import copy | |
import torch | |
import numpy as np | |
import matplotlib | |
matplotlib.use("Agg") | |
import matplotlib.pyplot as plt | |
from shutil import copyfile | |
def get_proper_cuda_device(device, verbose=True):
    """Resolve CUDA device specifiers to the indices available on this machine.

    Args:
        device: A single specifier or a list of specifiers. Each one is either
            an int index or a string of the form ``"cuda:<index>"``.
        verbose: When true, print the detected GPU count and every specifier
            that is dropped because its index is out of range.

    Returns:
        A new list of int device indices, in input order, keeping only those
        that satisfy ``0 <= index < torch.cuda.device_count()``. The list
        passed in by the caller is left untouched.

    Raises:
        ValueError: If a specifier is neither an int nor a ``"cuda:<index>"``
            string.
    """
    requested = device if isinstance(device, list) else [device]
    count = torch.cuda.device_count()
    if verbose:
        print("[Builder]: Found {} gpu".format(count))
    valid = []
    for d in requested:
        did = None
        if isinstance(d, str):
            # Anchor the pattern and capture the digits: an unanchored search
            # followed by slicing let malformed names reach ``int()``.
            match = re.fullmatch(r"cuda:(\d+)", d.strip())
            if match:
                did = int(match.group(1))
        elif isinstance(d, int):
            did = d
        if did is None:
            raise ValueError("[Builder]: Wrong cuda id {}".format(d))
        if did < 0 or did >= count:
            if verbose:
                print("[Builder]: {} is not found, ignore.".format(d))
        else:
            valid.append(did)
    return valid
def get_proper_device(devices, verbose=True):
    """Normalise a device request into a list of usable devices.

    Args:
        devices: A single device or a list of devices. Strings containing
            ``"cpu"`` request the cpu; ints and strings containing ``"cuda"``
            request GPUs.
        verbose: Forwarded to ``get_proper_cuda_device``; also controls the
            message printed when falling back to the cpu.

    Returns:
        A list. For GPU requests it is the list returned by
        ``get_proper_cuda_device``, or ``["cpu"]`` if that list is empty.
        Otherwise it is a shallow copy of the input (wrapped in a list when a
        single device was given).

    Raises:
        AssertionError: If the request mixes cpu and CUDA devices.
    """
    origin = copy.copy(devices)
    devices = copy.copy(devices)
    if not isinstance(devices, list):
        devices = [devices]
    # Only strings can name a device by text; checking the type first keeps
    # int CUDA indices from failing on a missing ``find`` method.
    use_cpu = any(isinstance(d, str) and "cpu" in d for d in devices)
    use_gpu = any(
        isinstance(d, int) or (isinstance(d, str) and "cuda" in d)
        for d in devices
    )
    assert not (use_cpu and use_gpu), "{} contains cpu and cuda device.".format(devices)
    if use_gpu:
        devices = get_proper_cuda_device(devices, verbose)
        if len(devices) == 0:
            if verbose:
                print("[Builder]: Failed to find any valid gpu in {}, use `cpu`.".format(origin))
            devices = ["cpu"]
    return devices
def _file_at_step(step):
    """Return the checkpoint file name for training step ``step``.

    The step is split into whole thousands and the remainder, so step 12345
    maps to ``save_12k345.pkg``.
    """
    thousands = int(step // 1000)
    remainder = int(step % 1000)
    return "save_{}k{}.pkg".format(thousands, remainder)
def _file_best():
    """Return the file name used for the copy of the best checkpoint."""
    best_name = "trained.pkg"
    return best_name
def save(global_step, graph, optim, criterion_dict=None, pkg_dir="", is_best=False, max_checkpoints=None):
    """Write a training checkpoint and optionally prune older ones.

    Args:
        global_step: Training step; stored in the checkpoint and used to build
            the file name via ``_file_at_step``.
        graph: Model to save. If it has a ``module`` attribute (as
            DataParallel wrappers do), the wrapped module's state is saved.
        optim: Optimizer whose ``state_dict()`` is stored. Required.
        criterion_dict: Optional mapping of name -> object with
            ``state_dict()``; each state is stored under ``"criterion"``.
        pkg_dir: Directory the checkpoint is written to. An empty string
            means the current working directory.
        is_best: When true, also copy the checkpoint to the ``_file_best()``
            name inside ``pkg_dir``.
        max_checkpoints: When not ``None``, step checkpoints in ``pkg_dir``
            beyond this count are deleted, lowest step first.

    Raises:
        ValueError: If ``optim`` is ``None``.
    """
    if optim is None:
        raise ValueError("cannot save without optimzier")
    # DataParallel wrap model in attr `module`; store the unwrapped weights.
    model = graph.module if hasattr(graph, "module") else graph
    state = {
        "global_step": global_step,
        "graph": model.state_dict(),
        "optim": optim.state_dict(),
        "criterion": {},
    }
    if criterion_dict is not None:
        for k in criterion_dict:
            state["criterion"][k] = criterion_dict[k].state_dict()
    save_path = os.path.join(pkg_dir, _file_at_step(global_step))
    best_path = os.path.join(pkg_dir, _file_best())
    torch.save(state, save_path)
    if is_best:
        copyfile(save_path, best_path)
    if max_checkpoints is not None:
        # Match whole file names with at least one digit on each side of "k",
        # so stray files (e.g. "save_k.pkg", "*.pkg.bak") cannot break int().
        pattern = re.compile(r"save_(\d+)k(\d+)\.pkg")
        history = []
        # os.listdir("") raises, while os.path.join("", name) above resolves
        # to the working directory; list "." to stay consistent with it.
        for file_name in os.listdir(pkg_dir or "."):
            match = pattern.fullmatch(file_name)
            if match:
                number = int(match.group(1)) * 1000 + int(match.group(2))
                history.append((number, file_name))
        history.sort()
        excess = max(len(history) - max_checkpoints, 0)
        for _, file_name in history[:excess]:
            path = os.path.join(pkg_dir, file_name)
            print("[Checkpoint]: remove {} to keep {} checkpoints".format(path, max_checkpoints))
            if os.path.exists(path):
                os.remove(path)
def load(step_or_path, graph, optim=None, criterion_dict=None, pkg_dir="", device=None):
    """Restore a checkpoint written by ``save``.

    Args:
        step_or_path: An int step (resolved with ``_file_at_step``), the string
            ``"best"`` (resolved with ``_file_best``), or a file name/path.
        graph: Model to load into. If it has a ``module`` attribute (as
            DataParallel wrappers do), the wrapped module is loaded, matching
            what ``save`` stores.
        optim: Optional optimizer to restore from the ``"optim"`` entry.
        criterion_dict: Optional mapping of name -> object with
            ``load_state_dict``; restored from the ``"criterion"`` entry.
        pkg_dir: Directory prepended to the file name. When ``None``, a string
            ``step_or_path`` is used as the path as-is.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The stored ``global_step``, or ``None`` when the checkpoint cannot be
        located (a message is printed in that case).
    """
    step = step_or_path
    save_path = None
    # format() instead of "+" so that pkg_dir=None does not raise here.
    print("LOADING FROM pkg_dir: {}".format(pkg_dir))
    if isinstance(step, int):
        # A missing directory means "relative to the working directory".
        base_dir = pkg_dir if pkg_dir is not None else ""
        save_path = os.path.join(base_dir, _file_at_step(step))
    if isinstance(step, str):
        if pkg_dir is not None:
            if step == "best":
                save_path = os.path.join(pkg_dir, _file_best())
            else:
                save_path = os.path.join(pkg_dir, step)
        else:
            save_path = step
    if save_path is not None and not os.path.exists(save_path):
        print("[Checkpoint]: Failed to find {}".format(save_path))
        return
    if save_path is None:
        print("[Checkpoint]: Cannot load the checkpoint with given step or filename or `best`")
        return
    # begin to load
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(save_path, map_location=device)
    global_step = state["global_step"]
    # save() stores the unwrapped module's weights, so load into the same
    # object; a DataParallel wrapper prefixes its keys with "module.".
    model = graph.module if hasattr(graph, "module") else graph
    model.load_state_dict(state["graph"])
    if optim is not None:
        optim.load_state_dict(state["optim"])
    if criterion_dict is not None:
        for k in criterion_dict:
            criterion_dict[k].load_state_dict(state["criterion"][k])
    model.set_actnorm_init(inited=True)
    print("[Checkpoint]: Load {} successfully".format(save_path))
    return global_step
def __save_figure_to_numpy(fig):
    """Return the rendered canvas of ``fig`` as a ``(height, width, 3)`` uint8 array.

    The canvas must already be drawn. ``tostring_rgb`` is used when the canvas
    still provides it; Matplotlib versions that removed it are served through
    ``buffer_rgba`` with the alpha channel dropped. The result is a writable
    copy that does not alias the canvas buffer.
    """
    canvas = fig.canvas
    width, height = canvas.get_width_height()
    if hasattr(canvas, "tostring_rgb"):
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        flat = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        return flat.reshape((height, width, 3)).copy()
    rgba = np.asarray(canvas.buffer_rgba(), dtype=np.uint8)
    return rgba.reshape((height, width, 4))[..., :3].copy()
def __to_ndarray_list(tensors, titles):
    """Convert tensors to NumPy arrays and pair them with their titles.

    Args:
        tensors: A list of items, or a single item (which is wrapped in a list
            together with ``titles``). Torch tensors are moved to the cpu,
            detached and converted; any other item is passed through as-is.
        titles: One title per tensor.

    Returns:
        ``(arrays, titles)`` where ``arrays`` is a new list; the caller's list
        is not modified.

    Raises:
        AssertionError: If the number of titles differs from the number of
            tensors.
    """
    if not isinstance(tensors, list):
        tensors = [tensors]
        titles = [titles]
    assert len(titles) == len(tensors),\
        "[visualizer]: {} titles are not enough for {} tensors".format(
            len(titles), len(tensors))
    # Build a fresh list instead of overwriting entries of the input list.
    arrays = [
        t.cpu().detach().numpy() if torch.is_tensor(t) else t
        for t in tensors
    ]
    return arrays, titles
def __get_figures(num_tensors, figsize):
    """Create a figure with ``num_tensors`` stacked rows in a single column.

    Returns ``(fig, axes)`` where ``axes`` is always a NumPy array, including
    when ``plt.subplots`` hands back one bare axes object.
    """
    figure, axis_grid = plt.subplots(num_tensors, 1, figsize=figsize)
    if isinstance(axis_grid, np.ndarray):
        return figure, axis_grid
    # Wrap the lone axes so callers can iterate over the result.
    return figure, np.asarray([axis_grid])
def __make_dir(file_name, plot_dir):
    """Create ``plot_dir`` (including parents) when a plot file will be written.

    Nothing happens when ``file_name`` is ``None``. An already existing path is
    left alone.
    """
    if file_name is None:
        return
    try:
        os.makedirs(plot_dir)
    except FileExistsError:
        # The path is already there, possibly created by another process
        # between a check and the call; attempting the creation and handling
        # this case avoids that check-then-create race.
        pass
def __draw(fig, file_name, plot_dir):
    """Finish a figure by saving it to disk or returning its pixels.

    Args:
        fig: The figure to finish; it is closed in both cases.
        file_name: When not ``None``, the figure is written to
            ``<plot_dir>/<file_name>.png``.
        plot_dir: Directory used for the PNG file.

    Returns:
        ``None`` when a file was written, otherwise the array produced by
        ``__save_figure_to_numpy`` after a tight layout and a canvas draw.
    """
    if file_name is not None:
        # Save this figure explicitly: plt.savefig writes pyplot's *current*
        # figure, which is not guaranteed to be ``fig``.
        fig.savefig('{}/{}.png'.format(plot_dir, file_name), format='png')
        plt.close(fig)
        return None
    fig.tight_layout()
    fig.canvas.draw()
    data = __save_figure_to_numpy(fig)
    plt.close(fig)
    return data
def __prepare_cond(autoreg, control, data_device):
    """Build the conditioning tensor from autoregressive and control arrays.

    Both inputs must be 3-D NumPy arrays. Each is flattened over its last two
    axes, the two results are joined along axis 1, a trailing axis of length 1
    is appended, and the tensor is moved to ``data_device``.
    """
    batch, steps, feats = autoreg.shape
    flat_autoreg = autoreg.reshape((batch, steps * feats))
    batch, steps, feats = control.shape
    flat_control = control.reshape((batch, steps * feats))
    joined = np.concatenate((flat_autoreg, flat_control), axis=1)
    cond = torch.from_numpy(joined[..., np.newaxis])
    return cond.to(data_device)
def __generate_sample(graph, data_batch, device, eps_std=1.0):
    """Sample a sequence from ``graph`` one frame at a time.

    The first ``seqlen`` frames of ``data_batch["autoreg"]`` seed a sliding
    window. At each step the window and a ``seqlen + 1`` frame slice of
    ``data_batch["control"]`` form the condition, the model is run in reverse
    mode, and the sampled frame is appended to the window while the oldest
    frame is dropped.

    Args:
        graph: The model; called with ``z=None, cond=..., eps_std=...,
            reverse=True``. Its ``init_lstm_hidden`` is invoked first (on
            ``graph.module`` when that attribute exists).
        data_batch: Mapping with tensor entries ``"seqlen"``, ``"autoreg"``
            and ``"control"``.
        device: Device the conditioning tensor is moved to.
        eps_std: Forwarded to the model.

    Returns:
        A NumPy array: the sampled frames concatenated with the control
        signal along the last axis.
    """
    print("generate_sample")
    # The first entry is taken as the window length for the whole batch.
    seqlen = int(data_batch["seqlen"].cpu()[0].numpy())
    autoreg_all = data_batch["autoreg"].cpu().numpy()
    control_all = data_batch["control"].cpu().numpy()
    print("autoreg_all: " +str(autoreg_all.shape))
    # DataParallel wrappers expose the real model as `module`.
    model = graph.module if hasattr(graph, "module") else graph
    model.init_lstm_hidden()
    sampled_all = np.zeros(autoreg_all.shape)
    sampled_all[:, :seqlen, :] = autoreg_all[:, :seqlen, :]
    autoreg = autoreg_all[:, :seqlen, :]
    for i in range(control_all.shape[1] - seqlen):
        control = control_all[:, i:(i + seqlen + 1), :]
        cond = __prepare_cond(autoreg, control, device)
        sampled = graph(z=None, cond=cond, eps_std=eps_std, reverse=True)
        # detach() so the conversion also works when the output tracks grads.
        sampled = sampled.cpu().detach().numpy()
        # presumably sampled is (batch, features, 1) - TODO confirm with model
        sampled_all[:, (i + seqlen), :] = sampled[:, :, 0]
        autoreg = np.concatenate((autoreg[:, 1:, :], sampled.swapaxes(1, 2)), axis=1)
    anim_clip = np.concatenate((sampled_all, control_all), axis=2)
    return anim_clip
def __get_size_for_spec(tensors):
    """Return a ``(width, height)`` figure size for a stack of spectrograms.

    Width is one tenth of the first tensor's second dimension (rounded up),
    limited to the range 3..10. Height is 3 per tensor, with a minimum of 3.
    """
    n_cols = tensors[0].shape[1]
    width = np.max([np.min([int(np.ceil(n_cols / 10.0)), 10]), 3])
    height = np.max([3, 3 * len(tensors)])
    return (width, height)
def __get_aspect(spectrogram):
    """Return the image aspect ratio used when plotting ``spectrogram``.

    The value is ``3 / width * columns / rows``, where ``width`` is one tenth
    of the column count (rounded up) limited to the range 3..10.
    """
    n_rows, n_cols = spectrogram.shape[0], spectrogram.shape[1]
    width = np.max([np.min([int(np.ceil(n_cols / 10.0)), 10]), 3])
    # The "more than 50 columns" case and its alternative evaluate the same
    # expression, so a single formula covers both.
    return 3.0 / width * n_cols / n_rows
def plot_prob(done, title="", file_name=None, plot_dir=None):
    """Plot one heat map per array in ``done``, one row per frame.

    Args:
        done: A tensor/array or a list of them. Each is flattened to 2-D,
            keeping its last axis, and shown with a colour range of 0..1.
        title: Title, or list of titles matching ``done``.
        file_name: When given, the figure is saved as a PNG under ``plot_dir``
            instead of being returned.
        plot_dir: Directory for the PNG; created if needed.

    Returns:
        Whatever ``__draw`` returns: ``None`` after saving to a file, otherwise
        the rendered image as an array.
    """
    __make_dir(file_name, plot_dir)
    done, title = __to_ndarray_list(done, title)
    # Reshape into a new list rather than assigning back into ``done``, which
    # may be the very list the caller passed in.
    grids = [np.reshape(d, (-1, d.shape[-1])) for d in done]
    figsize = (5, 5 * len(grids))
    fig, axes = __get_figures(len(grids), figsize)
    for ax, d, t in zip(axes, grids, title):
        ax.imshow(d, vmin=0, vmax=1, cmap="Blues", aspect=d.shape[1]/d.shape[0])
        ax.set_title(t)
        ax.set_yticks(np.arange(d.shape[0]))
        labels = ["Frame{}".format(row + 1) for row in range(d.shape[0])]
        ax.set_yticklabels(labels)
        # Minor ticks sit between rows so the grid separates the frames.
        ax.set_yticks(np.arange(d.shape[0])-.5, minor=True)
        ax.grid(which="minor", color="g", linestyle='-.', linewidth=1)
        ax.invert_yaxis()
    return __draw(fig, file_name, plot_dir)