Spaces:
Runtime error
Runtime error
# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. | |
# This program is free software; you can redistribute it and/or modify | |
# it under the terms of the MIT License. | |
# This program is distributed in the hope that it will be useful, | |
# but WITHOUT ANY WARRANTY; without even the implied warranty of | |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
# MIT License for more details. | |
import os | |
import glob | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch | |
def intersperse(lst, item):
    """Return a new list with ``item`` placed before, between and after every element of ``lst``.

    Used to add a blank symbol around each token, e.g.
    ``intersperse([5, 7], 0) == [0, 5, 0, 7, 0]``. The result always has
    ``2 * len(lst) + 1`` entries, so ``lst`` must support ``len()``.
    """
    # Even slots hold the separator; element k of ``lst`` lands in odd slot 2k + 1.
    interleaved = [item] * (2 * len(lst) + 1)
    for index, element in enumerate(lst):
        interleaved[2 * index + 1] = element
    return interleaved
def parse_filelist(filelist_path, split_char="|"):
    """Read a delimited file list and return one list of string fields per line.

    The file is decoded as UTF-8. Each line is stripped of leading/trailing
    whitespace (including the newline) and then split on ``split_char``.
    A blank line therefore yields ``['']``.
    """
    records = []
    with open(filelist_path, encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(split_char)
            records.append(fields)
    return records
def latest_checkpoint_path(dir_path, regex="grad_*.pt"):
    """Return the path of the highest-numbered checkpoint in ``dir_path``.

    Despite its name, ``regex`` is a :mod:`glob` pattern, not a regular expression.
    Matching files are ranked by the integer formed from all decimal digits in
    their path *relative to* ``dir_path``, so ``grad_1000.pt`` outranks
    ``grad_999.pt``. Digits in ``dir_path`` itself (e.g. ``run2/``) do not
    affect the ranking. Files whose relative path contains no digits rank lowest.

    Args:
        dir_path: Directory to search.
        regex: Glob pattern joined onto ``dir_path``.

    Returns:
        The path (str) of the highest-ranked matching file.

    Raises:
        IndexError: If nothing matches. IndexError is kept for compatibility
            with the previous ``f_list[-1]`` failure, now with a descriptive message.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        raise IndexError(f"No checkpoint matching {regex!r} found in {dir_path!r}")

    def _step(path):
        # Only digits below dir_path count. Gluing the directory's digits onto
        # every number mis-ranks zero-padded names (e.g. "3"+"009" > "3"+"10"),
        # and a path with no digits at all would make int("") raise.
        digits = "".join(filter(str.isdigit, os.path.relpath(path, dir_path)))
        return int(digits) if digits else -1

    # Stable sort + last element: among equal keys, keep the original's tie-breaking.
    f_list.sort(key=_step)
    return f_list[-1]
def load_checkpoint(logdir, model, num=None):
    """Load a ``grad_*.pt`` state dict from ``logdir`` into ``model`` and return ``model``.

    Args:
        logdir: Directory holding the checkpoints.
        model: Object exposing ``load_state_dict`` (presumably a
            ``torch.nn.Module``; confirm at call sites). It is modified in place.
        num: Checkpoint number to load as ``grad_{num}.pt``. When None, the
            highest-numbered checkpoint found by ``latest_checkpoint_path`` is used.

    Returns:
        The same ``model`` object, for chaining.
    """
    if num is not None:
        model_path = os.path.join(logdir, f"grad_{num}.pt")
    else:
        model_path = latest_checkpoint_path(logdir, regex="grad_*.pt")
    print(f'Loading checkpoint {model_path}...')
    # A callable map_location is called with (storage, location); returning the
    # storage unchanged keeps every tensor on the CPU regardless of where it was saved.
    state_dict = torch.load(model_path, map_location=lambda storage, location: storage)
    # strict=False: keys missing from, or unexpected by, the model are ignored silently.
    model.load_state_dict(state_dict, strict=False)
    return model
def save_figure_to_numpy(fig):
    """Return the rendered pixels of ``fig`` as a ``(height, width, 3)`` ``np.uint8`` RGB array.

    The figure's canvas must already have been drawn (e.g. ``fig.canvas.draw()``).
    The result is an independent, writable copy that stays valid after the
    figure is closed.

    Prefers ``canvas.buffer_rgba()`` (Agg canvases) and drops the alpha channel.
    ``tostring_rgb()`` is deprecated since Matplotlib 3.8 and removed in 3.10,
    and ``np.fromstring`` in binary mode is deprecated in NumPy. The legacy
    ``tostring_rgb`` path is used only when the canvas has no ``buffer_rgba``.
    """
    canvas = fig.canvas
    if hasattr(canvas, "buffer_rgba"):
        # buffer_rgba() exposes the renderer's (H, W, 4) buffer; taking its shape
        # directly also avoids the logical-vs-physical size mismatch on HiDPI canvases.
        rgba = np.asarray(canvas.buffer_rgba())
        return rgba[..., :3].copy()
    # Fallback for canvases without buffer_rgba: np.frombuffer is read-only, so copy.
    data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return data.reshape(canvas.get_width_height()[::-1] + (3,)).copy()
def plot_tensor(tensor):
    """Render a 2-D array as a heat-map with a colour bar and return its pixels.

    Args:
        tensor: 2-D array-like accepted by ``Axes.imshow`` (presumably a
            spectrogram or alignment matrix; confirm at call sites). Row 0 is
            drawn at the bottom (``origin="lower"``).

    Returns:
        The RGB pixels of the rendered 12x3-inch figure, as returned by
        ``save_figure_to_numpy`` (a ``(height, width, 3)`` uint8 array).
    """
    # NOTE: resets the global pyplot style; this side effect is kept for compatibility.
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # The canvas must be rendered before its pixel buffer can be read.
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even on error, so pyplot does not keep it alive.
        plt.close(fig)
    return data
def save_plot(tensor, savepath):
    """Render a 2-D array as a heat-map with a colour bar and write it to ``savepath``.

    Args:
        tensor: 2-D array-like accepted by ``Axes.imshow`` (presumably a
            spectrogram or alignment matrix; confirm at call sites). Row 0 is
            drawn at the bottom (``origin="lower"``).
        savepath: Destination path. Matplotlib infers the image format from
            its extension.
    """
    # NOTE: resets the global pyplot style; this side effect is kept for compatibility.
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # savefig renders the figure itself, so an explicit canvas.draw()
        # beforehand would only be a wasted extra render.
        fig.savefig(savepath)
    finally:
        # Close this specific figure, even on error, so pyplot does not keep it alive.
        plt.close(fig)