# Captured from a Hugging Face Space whose page status showed "Runtime error".
| import os | |
| import re | |
| import copy | |
| import torch | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from shutil import copyfile | |
def get_proper_cuda_device(device, verbose=True):
    """Resolve CUDA device specifiers to integer ids that exist on this machine.

    Args:
        device: a single specifier or a list of them. Each specifier is an
            ``int`` id or a string of the form ``"cuda:<id>"`` (surrounding
            whitespace is ignored).
        verbose: print the detected gpu count and every id that is skipped.

    Returns:
        A new list of valid integer ids, in input order. Ids that are negative
        or not below ``torch.cuda.device_count()`` are dropped. The caller's
        list is never modified.

    Raises:
        ValueError: if a specifier is neither an int nor a ``"cuda:<id>"`` string.
    """
    devices = device if isinstance(device, list) else [device]
    count = torch.cuda.device_count()
    if verbose:
        print("[Builder]: Found {} gpu".format(count))
    resolved = []
    for d in devices:
        did = None
        if isinstance(d, str):
            # Require the whole string to be "cuda:<digits>". A prefix search
            # followed by slicing d[5:] crashed on e.g. " cuda:0" or "cuda:0x".
            match = re.fullmatch(r"cuda:(\d+)", d.strip())
            if match:
                did = int(match.group(1))
        elif isinstance(d, int):
            did = d
        if did is None:
            raise ValueError("[Builder]: Wrong cuda id {}".format(d))
        if did < 0 or did >= count:
            if verbose:
                print("[Builder]: {} is not found, ignore.".format(d))
            continue
        resolved.append(did)
    return resolved
def get_proper_device(devices, verbose=True):
    """Normalize a device spec into a list of usable devices, falling back to cpu.

    Args:
        devices: a single spec or a list of specs. A spec is a string that
            contains "cpu" or "cuda" (e.g. "cpu", "cuda:0"), or an integer
            cuda id.
        verbose: forwarded to ``get_proper_cuda_device``; also controls the
            fallback message.

    Returns:
        For cpu specs, the specs as a list. For cuda specs, the list of valid
        integer gpu ids from ``get_proper_cuda_device``. Returns ``["cpu"]``
        when that list (or the input) is empty.

    Raises:
        AssertionError: if cpu and cuda specs are mixed.
    """
    origin = copy.copy(devices)
    devices = copy.copy(devices)
    if not isinstance(devices, list):
        devices = [devices]

    def _is_cpu(d):
        return isinstance(d, str) and "cpu" in d

    def _is_gpu(d):
        # Check the type first so integer ids never reach str methods.
        return isinstance(d, int) or (isinstance(d, str) and "cuda" in d)

    use_cpu = any(_is_cpu(d) for d in devices)
    use_gpu = any(_is_gpu(d) for d in devices)
    if use_cpu and use_gpu:
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept for callers that already catch it.
        raise AssertionError("{} contains cpu and cuda device.".format(devices))
    if use_gpu:
        devices = get_proper_cuda_device(devices, verbose)
    if len(devices) == 0:
        if verbose:
            print("[Builder]: Failed to find any valid gpu in {}, use `cpu`.".format(origin))
        devices = ["cpu"]
    return devices
def _file_at_step(step):
    """Return the checkpoint file name for ``step``, e.g. 12345 -> "save_12k345.pkg"."""
    thousands, remainder = divmod(step, 1000)
    return "save_{}k{}.pkg".format(int(thousands), int(remainder))
def _file_best():
    """Return the file name under which the best checkpoint is kept."""
    best_name = "trained.pkg"
    return best_name
def save(global_step, graph, optim, criterion_dict=None, pkg_dir="", is_best=False, max_checkpoints=None):
    """Write a training checkpoint with torch.save and optionally prune old ones.

    The checkpoint is a dict with keys "global_step", "graph" (model state),
    "optim" (optimizer state) and "criterion" (name -> criterion state). It is
    written to ``<pkg_dir>/<_file_at_step(global_step)>``.

    Args:
        global_step: training step; determines the file name.
        graph: model to save. If it has a ``module`` attribute (as a
            DataParallel wrapper does), the wrapped module's state is saved.
        optim: optimizer whose state is saved; required.
        criterion_dict: optional mapping of name -> object with ``state_dict()``.
        pkg_dir: directory to write into ("" means the current directory).
        is_best: also copy the checkpoint to ``<pkg_dir>/<_file_best()>``.
        max_checkpoints: if not None, keep only this many ``save_<k>k<r>.pkg``
            files in ``pkg_dir`` (highest steps), deleting the rest.

    Raises:
        ValueError: if ``optim`` is None.
    """
    if optim is None:
        raise ValueError("cannot save without optimizer")
    # DataParallel wraps the model in attr `module`; store the bare state dict.
    model = graph.module if hasattr(graph, "module") else graph
    state = {
        "global_step": global_step,
        "graph": model.state_dict(),
        "optim": optim.state_dict(),
        "criterion": {},
    }
    if criterion_dict is not None:
        for k in criterion_dict:
            state["criterion"][k] = criterion_dict[k].state_dict()
    save_path = os.path.join(pkg_dir, _file_at_step(global_step))
    best_path = os.path.join(pkg_dir, _file_best())
    torch.save(state, save_path)
    if is_best:
        copyfile(save_path, best_path)
    if max_checkpoints is not None:
        _prune_checkpoints(pkg_dir, max_checkpoints)


def _prune_checkpoints(pkg_dir, max_checkpoints):
    """Delete the lowest-step ``save_<k>k<r>.pkg`` files beyond ``max_checkpoints``."""
    pattern = re.compile(r"save_(\d+)k(\d+)\.pkg")
    history = []
    # os.listdir("") raises, whereas os.path.join("", name) refers to the cwd.
    for file_name in os.listdir(pkg_dir or "."):
        match = pattern.fullmatch(file_name)
        if match:
            step = int(match.group(1)) * 1000 + int(match.group(2))
            history.append((step, file_name))
    history.sort()
    excess = max(len(history) - max_checkpoints, 0)
    # Delete the files actually found rather than re-deriving names from steps.
    for _, file_name in history[:excess]:
        path = os.path.join(pkg_dir, file_name)
        print("[Checkpoint]: remove {} to keep {} checkpoints".format(path, max_checkpoints))
        if os.path.exists(path):
            os.remove(path)
def load(step_or_path, graph, optim=None, criterion_dict=None, pkg_dir="", device=None):
    """Restore a checkpoint into ``graph`` and, optionally, optimizer and criteria.

    Args:
        step_or_path: one of the following:
            - an int step, resolved to ``_file_at_step(step)`` in ``pkg_dir``;
            - the string "best", resolved to ``_file_best()`` in ``pkg_dir``;
            - any other string, a file name joined onto ``pkg_dir``.
            When ``pkg_dir`` is None, a string is used as a path as-is.
        graph: model to restore. A wrapper exposing ``module`` is unwrapped,
            matching how ``save`` stores the state. The (unwrapped) model must
            provide ``set_actnorm_init(inited=...)``.
        optim: optional optimizer to restore from the "optim" entry.
        criterion_dict: optional mapping name -> object with ``load_state_dict``.
        pkg_dir: checkpoint directory, or None (see ``step_or_path``).
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The stored "global_step". Returns None, after printing a message,
        when the checkpoint path cannot be formed or does not exist.
    """
    step = step_or_path
    save_path = None
    print("LOADING FROM pkg_dir: {}".format(pkg_dir))
    if isinstance(step, int):
        save_path = os.path.join(pkg_dir or "", _file_at_step(step))
    elif isinstance(step, str):
        if pkg_dir is not None:
            if step == "best":
                save_path = os.path.join(pkg_dir, _file_best())
            else:
                save_path = os.path.join(pkg_dir, step)
        else:
            save_path = step
    if save_path is None:
        print("[Checkpoint]: Cannot load the checkpoint with given step or filename or `best`")
        return
    if not os.path.exists(save_path):
        print("[Checkpoint]: Failed to find {}".format(save_path))
        return
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(save_path, map_location=device)
    # `save` stores the unwrapped model's state dict, so restore into it too.
    model = graph.module if hasattr(graph, "module") else graph
    global_step = state["global_step"]
    model.load_state_dict(state["graph"])
    if optim is not None:
        optim.load_state_dict(state["optim"])
    if criterion_dict is not None:
        for k in criterion_dict:
            criterion_dict[k].load_state_dict(state["criterion"][k])
    model.set_actnorm_init(inited=True)
    print("[Checkpoint]: Load {} successfully".format(save_path))
    return global_step
def __save_figure_to_numpy(fig):
    """Return the rendered pixels of ``fig`` as a (height, width, 3) uint8 array.

    The figure must already be drawn (``fig.canvas.draw()``) on a canvas that
    provides ``buffer_rgba()``, such as the Agg canvas selected at import time.
    """
    # canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and removed in
    # 3.10, and np.fromstring's binary mode is deprecated. The RGBA buffer
    # already has the physical pixel shape (h, w, 4), so no reshape with the
    # logical get_width_height() size is needed (that size is wrong when the
    # device pixel ratio is not 1).
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha and copy so the result does not alias the canvas buffer.
    return np.array(rgba[:, :, :3], dtype=np.uint8)
def __to_ndarray_list(tensors, titles):
    """Normalize plotting inputs into parallel lists, converting torch tensors to numpy.

    Args:
        tensors: one array/tensor or a list of them.
        titles: one title, or a list with one title per tensor.

    Returns:
        ``(arrays, titles)``. ``arrays`` is a new list in which every torch
        tensor is replaced by a detached cpu numpy array; other items are kept
        as-is. A non-list ``tensors`` is wrapped together with ``titles``.

    Raises:
        AssertionError: if the numbers of titles and tensors differ.
    """
    if not isinstance(tensors, list):
        tensors = [tensors]
        titles = [titles]
    if len(titles) != len(tensors):
        # Explicit raise so the check is not stripped under `python -O`.
        raise AssertionError(
            "[visualizer]: {} titles are not enough for {} tensors".format(
                len(titles), len(tensors)))
    # Build a new list instead of overwriting entries of the caller's list.
    arrays = [t.cpu().detach().numpy() if torch.is_tensor(t) else t for t in tensors]
    return arrays, titles
def __get_figures(num_tensors, figsize):
    """Create a figure with ``num_tensors`` vertically stacked subplots.

    Returns ``(fig, axes)``, where ``axes`` is always a numpy array, even for
    a single subplot (``plt.subplots`` returns a bare Axes in that case).
    """
    fig, ax_or_axes = plt.subplots(num_tensors, 1, figsize=figsize)
    if isinstance(ax_or_axes, np.ndarray):
        return fig, ax_or_axes
    return fig, np.asarray([ax_or_axes])
def __make_dir(file_name, plot_dir):
    """Create ``plot_dir`` (with parents) when a plot is going to be saved to a file.

    Does nothing when ``file_name`` is None (no file will be written), or when
    ``plot_dir`` is empty or None (there is no directory to create).
    """
    if file_name is None or not plot_dir:
        return
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(plot_dir, exist_ok=True)
def __draw(fig, file_name, plot_dir):
    """Save ``fig`` to ``<plot_dir>/<file_name>.png``, or render it to an array.

    Returns None after saving to a file. Otherwise returns the image produced
    by ``__save_figure_to_numpy`` after a tight layout and a draw. The figure
    is closed in both cases, even if saving or rendering raises.
    """
    try:
        if file_name is not None:
            # Save this figure explicitly: plt.savefig writes whichever figure
            # pyplot considers current, which need not be `fig`.
            path = os.path.join(plot_dir or "", "{}.png".format(file_name))
            fig.savefig(path, format='png')
            return None
        fig.tight_layout()
        fig.canvas.draw()
        return __save_figure_to_numpy(fig)
    finally:
        plt.close(fig)
def __prepare_cond(autoreg, control, data_device):
    """Flatten two 3-D arrays per sample and stack them into one conditioning tensor.

    Each input of shape (n, a, b) becomes (n, a * b). The two results are
    concatenated along axis 1, given a trailing singleton axis, converted with
    ``torch.from_numpy`` and moved to ``data_device``.
    """
    flattened = []
    for arr in (autoreg, control):
        batch, steps, feats = arr.shape
        flattened.append(arr.reshape((batch, steps * feats)))
    stacked = np.concatenate(flattened, axis=1)[:, :, np.newaxis]
    return torch.from_numpy(stacked).to(data_device)
def __generate_sample(graph, data_batch, device, eps_std=1.0):
    """Autoregressively sample frames from ``graph`` for one data batch.

    The first ``seqlen`` frames of ``data_batch["autoreg"]`` are copied to the
    output. Each later frame is produced by calling ``graph`` in reverse with
    a condition built from the previous ``seqlen`` frames and ``seqlen + 1``
    control frames. The new frame is then appended to the sliding window.

    Args:
        graph: model, or a wrapper exposing ``module``, providing
            ``init_lstm_hidden()`` and callable as
            ``graph(z=None, cond=..., eps_std=..., reverse=True)``.
        data_batch: dict of tensors with keys "seqlen", "autoreg" and
            "control". "autoreg" and "control" are indexed as 3-D
            (batch, time, features).
        device: device the conditioning tensor is moved to.
        eps_std: forwarded to ``graph``.

    Returns:
        numpy array concatenating the sampled frames and the control frames
        along the last axis.
    """
    print("generate_sample")
    seqlen = data_batch["seqlen"].cpu()[0].numpy()
    autoreg_all = data_batch["autoreg"].cpu().numpy()
    control_all = data_batch["control"].cpu().numpy()
    print("autoreg_all: " + str(autoreg_all.shape))
    if hasattr(graph, "module"):
        graph.module.init_lstm_hidden()
    else:
        graph.init_lstm_hidden()
    sampled_all = np.zeros(autoreg_all.shape)
    sampled_all[:, :seqlen, :] = autoreg_all[:, :seqlen, :]
    autoreg = autoreg_all[:, :seqlen, :]
    # Pure inference: without no_grad the output can require grad, and
    # Tensor.numpy() refuses tensors that require grad.
    with torch.no_grad():
        for i in range(0, control_all.shape[1] - seqlen):
            control = control_all[:, i:(i + seqlen + 1), :]
            cond = __prepare_cond(autoreg, control, device)
            sampled = graph(z=None, cond=cond, eps_std=eps_std, reverse=True)
            sampled = sampled.cpu().numpy()
            # assumes sampled is (batch, features, 1) — TODO confirm against graph
            sampled_all[:, (i + seqlen), :] = sampled[:, :, 0]
            # Slide the window: drop the oldest frame and append the new sample.
            autoreg = np.concatenate((autoreg[:, 1:, :], sampled.swapaxes(1, 2)), axis=1)
    return np.concatenate((sampled_all, control_all), axis=2)
def __get_size_for_spec(tensors):
    """Return a (width, height) figure size for plotting ``tensors``.

    Width is ceil(columns / 10) of the first entry's second dimension, clamped
    to [3, 10]. Height is 3 per entry, with a minimum of 3.
    """
    n_cols = tensors[0].shape[1]
    width = np.clip(int(np.ceil(n_cols / 10.0)), 3, 10)
    height = np.maximum(3 * len(tensors), 3)
    return (width, height)
def __get_aspect(spectrogram):
    """Return the imshow aspect ratio for a 2-D ``spectrogram``.

    A nominal width is computed as ceil(columns / 10), clamped to [3, 10].
    The aspect is then (3 / width) * columns / rows.
    """
    n_rows, n_cols = spectrogram.shape[0], spectrogram.shape[1]
    fig_w = np.max([np.min([int(np.ceil(n_cols / 10.0)), 10]), 3])
    # One formula for every width; keeps the original evaluation order so the
    # floating-point result is bit-identical.
    return 3.0 / fig_w * n_cols / n_rows
def plot_prob(done, title="", file_name=None, plot_dir=None):
    """Plot one or more probability maps as heat maps with values in [0, 1].

    Args:
        done: an array/tensor or a list of them. Each is reshaped to
            (-1, last_dim) and drawn with one labelled row ("Frame<n>") per
            leading entry.
        title: a title, or a list with one title per entry of ``done``.
        file_name: if given, the figure is written to a file under
            ``plot_dir`` and None is returned.
        plot_dir: directory for ``file_name``.

    Returns:
        None when ``file_name`` is given, otherwise the rendered image array
        from ``__draw``.
    """
    __make_dir(file_name, plot_dir)
    done, title = __to_ndarray_list(done, title)
    # Build a new list so the caller's list and arrays are left untouched.
    maps = [np.reshape(d, (-1, d.shape[-1])) for d in done]
    fig, axes = __get_figures(len(maps), (5, 5 * len(maps)))
    for ax, d, t in zip(axes, maps, title):
        n_rows = d.shape[0]
        ax.imshow(d, vmin=0, vmax=1, cmap="Blues", aspect=d.shape[1] / d.shape[0])
        ax.set_title(t)
        ax.set_yticks(np.arange(n_rows))
        ax.set_yticklabels(["Frame{}".format(r + 1) for r in range(n_rows)])
        # Minor ticks halfway between rows carry the grid lines separating frames.
        ax.set_yticks(np.arange(n_rows) - .5, minor=True)
        ax.grid(which="minor", color="g", linestyle='-.', linewidth=1)
        ax.invert_yaxis()
    return __draw(fig, file_name, plot_dir)