|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import glob |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
import torch |
|
|
|
|
|
def intersperse(lst, item):
    """Return a new list with ``item`` placed before, between and after every element of ``lst``.

    For ``lst = [a, b]`` the result is ``[item, a, item, b, item]``; an empty
    ``lst`` yields ``[item]``. ``lst`` must support ``len()``.
    """
    # Pre-fill every slot with the separator, then overwrite the odd positions.
    interleaved = [item] * (2 * len(lst) + 1)
    for position, element in enumerate(lst):
        interleaved[2 * position + 1] = element
    return interleaved
|
|
|
|
|
def parse_filelist(filelist_path, split_char="|"):
    """Read a UTF-8 text file and split each line into fields.

    Args:
        filelist_path: Path of the file to read.
        split_char: Field separator (``"|"`` by default).

    Returns:
        A list with one entry per line; each entry is the list of fields
        obtained after stripping surrounding whitespace from the line.
    """
    rows = []
    with open(filelist_path, encoding='utf-8') as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            rows.append(cleaned.split(split_char))
    return rows
|
|
|
|
|
def latest_checkpoint_path(dir_path, regex="grad_*.pt"):
    """Return the path of the checkpoint with the highest step number in ``dir_path``.

    Args:
        dir_path: Directory to search.
        regex: A ``glob`` pattern (not a regular expression, despite the name,
            which is kept for backward compatibility) matched inside ``dir_path``.

    Returns:
        The matching path whose file name carries the largest number, e.g.
        ``grad_100.pt`` wins over ``grad_99.pt``.

    Raises:
        FileNotFoundError: If no file matches the pattern.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        raise FileNotFoundError(
            f"No checkpoint matching {regex!r} found in {dir_path!r}"
        )

    def _step(path):
        # Only look at the file name: digits in the directory part (e.g. "run2/")
        # must not be glued onto the step number.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        # Files without any digits (e.g. "grad_best.pt") sort before numbered ones
        # instead of crashing on int("").
        return int(digits) if digits else -1

    return max(f_list, key=_step)
|
|
|
|
|
def load_checkpoint(logdir, model, num=None):
    """Load weights from a ``grad_*.pt`` checkpoint in ``logdir`` into ``model``.

    Args:
        logdir: Directory containing the checkpoint files.
        model: Object exposing ``load_state_dict`` (e.g. a ``torch.nn.Module``).
        num: Step number of a specific checkpoint (``grad_{num}.pt``). When
            ``None``, the newest checkpoint found by ``latest_checkpoint_path``
            is used.

    Returns:
        The same ``model`` object, updated in place.
    """
    if num is None:
        model_path = latest_checkpoint_path(logdir, regex="grad_*.pt")
    else:
        model_path = os.path.join(logdir, f"grad_{num}.pt")
    print(f'Loading checkpoint {model_path}...')
    # Load every tensor onto the CPU regardless of the device it was saved from.
    # (The previous lambda had its parameters swapped: torch calls
    # map_location(storage, location); "cpu" states the intent directly.)
    model_dict = torch.load(model_path, map_location="cpu")
    # The loaded object is passed straight to load_state_dict, so the file is
    # expected to hold a state dict. strict=False tolerates key mismatches
    # (e.g. partial loading); report them instead of ignoring them silently.
    result = model.load_state_dict(model_dict, strict=False)
    missing = getattr(result, "missing_keys", ())
    unexpected = getattr(result, "unexpected_keys", ())
    if missing:
        print(f'Warning: {len(missing)} missing key(s) in checkpoint: {list(missing)}')
    if unexpected:
        print(f'Warning: {len(unexpected)} unexpected key(s) in checkpoint: {list(unexpected)}')
    return model
|
|
|
|
|
def save_figure_to_numpy(fig):
    """Return the rendered pixels of ``fig`` as a ``(height, width, 3)`` uint8 RGB array.

    The figure must already have been drawn (e.g. via ``fig.canvas.draw()``);
    this function only reads the canvas pixel buffer. The returned array is a
    writable copy, independent of the canvas memory.
    """
    canvas = fig.canvas
    if hasattr(canvas, "buffer_rgba"):
        # Agg-based canvases expose the RGBA buffer directly. This replaces the
        # deprecated np.fromstring + tostring_rgb path (tostring_rgb was removed
        # in Matplotlib 3.10), and the buffer already carries the true pixel shape,
        # which can differ from get_width_height() on HiDPI canvases.
        rgba = np.asarray(canvas.buffer_rgba())
        return np.array(rgba[..., :3], dtype=np.uint8)
    # Fallback for older canvases that only provide tostring_rgb().
    data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return data.reshape(canvas.get_width_height()[::-1] + (3,)).copy()
|
|
|
|
|
def plot_tensor(tensor):
    """Draw ``tensor`` as a heatmap with a colorbar and return the image as an RGB array.

    Args:
        tensor: Image-like data accepted by ``Axes.imshow``; presumably a 2-D
            array such as a (features, frames) matrix — confirm with callers.

    Returns:
        A ``(height, width, 3)`` uint8 array with the rendered figure.

    Note:
        Resets the global Matplotlib style to ``'default'`` as a side effect.
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # Render now so the canvas pixel buffer is populated before reading it.
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
    finally:
        # Always release this specific figure, even if plotting fails; otherwise
        # pyplot keeps it alive and figures accumulate across calls.
        plt.close(fig)
    return data
|
|
|
|
|
def save_plot(tensor, savepath):
    """Draw ``tensor`` as a heatmap with a colorbar and write the figure to ``savepath``.

    Args:
        tensor: Image-like data accepted by ``Axes.imshow``; presumably a 2-D
            array — confirm with callers.
        savepath: Destination file; the image format follows Matplotlib's
            ``savefig`` rules (typically inferred from the extension).

    Note:
        Resets the global Matplotlib style to ``'default'`` as a side effect.
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # savefig renders the figure itself, so no explicit canvas.draw() is needed.
        fig.savefig(savepath)
    finally:
        # Close this figure even on failure so pyplot does not leak it.
        plt.close(fig)
|
|
|
def save_plot_f0(tensor, savepath):
    """Plot ``tensor`` as a line against its index and write the figure to ``savepath``.

    Args:
        tensor: Sequence indexable along axis 0 (``tensor.shape[0]`` is used for
            the x axis); presumably a 1-D per-frame F0 contour, judging by the
            name — confirm with callers.
        savepath: Destination file; the image format follows Matplotlib's
            ``savefig`` rules (typically inferred from the extension).

    Note:
        Resets the global Matplotlib style to ``'default'`` as a side effect.
    """
    x = np.arange(0, tensor.shape[0], 1)
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.plot(x, tensor)
        fig.tight_layout()
        # savefig renders the figure itself, so no explicit canvas.draw() is needed.
        fig.savefig(savepath)
    finally:
        # Close this figure even on failure so pyplot does not leak it.
        plt.close(fig)