import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch


def parse_filelist(filelist_path, split_char="|"):
    """Read a delimited file list.

    Args:
        filelist_path: Path of a UTF-8 text file with one record per line.
        split_char: Separator between the fields of a record.

    Returns:
        A list with one entry per line; each entry is the list of fields
        obtained by stripping the line and splitting it on ``split_char``.
    """
    with open(filelist_path, encoding='utf-8') as f:
        return [line.strip().split(split_char) for line in f]


def load_model(model, saved_state_dict):
    """Load ``saved_state_dict`` into ``model``, tolerating missing keys.

    Every key of ``model.state_dict()`` is taken from ``saved_state_dict``
    when present. Keys absent from the checkpoint keep the model's current
    value and are reported on stdout. Keys that exist only in
    ``saved_state_dict`` are ignored.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        saved_state_dict: Mapping from parameter name to saved value.

    Returns:
        The same ``model`` instance, after loading.
    """
    new_state_dict = {}
    for k, v in model.state_dict().items():
        # Explicit membership test instead of a bare ``except`` so that
        # unrelated errors (and KeyboardInterrupt) are not swallowed.
        if k in saved_state_dict:
            new_state_dict[k] = saved_state_dict[k]
        else:
            print("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    return model


def _checkpoint_step(path):
    """Return the integer formed by the digits of ``path``'s file name.

    Only the base name is inspected so that digits in the directory part do
    not influence the ordering. Names without any digit yield ``-1`` so they
    sort before every numbered checkpoint instead of raising ``ValueError``.
    """
    digits = "".join(filter(str.isdigit, os.path.basename(path)))
    return int(digits) if digits else -1


def latest_checkpoint_path(dir_path, regex="grad_svc_*.pt"):
    """Return the checkpoint in ``dir_path`` with the highest number.

    Args:
        dir_path: Directory to search.
        regex: Glob pattern (despite the name) matched inside ``dir_path``.

    Returns:
        The matching path whose file name carries the largest number.

    Raises:
        IndexError: If no file in ``dir_path`` matches ``regex``.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        # Same exception type as indexing the empty list, with a clear message.
        raise IndexError(
            "no checkpoint matching %r found in %r" % (regex, dir_path))
    return max(f_list, key=_checkpoint_step)


def load_checkpoint(logdir, model, num=None):
    """Load a ``grad_svc_*.pt`` checkpoint from ``logdir`` into ``model``.

    Args:
        logdir: Directory containing the checkpoints.
        model: Module to receive the weights (loaded with ``strict=False``).
        num: Checkpoint number to load; when ``None`` the latest checkpoint
            found by ``latest_checkpoint_path`` is used.

    Returns:
        The same ``model`` instance, after loading.
    """
    if num is None:
        model_path = latest_checkpoint_path(logdir, regex="grad_svc_*.pt")
    else:
        model_path = os.path.join(logdir, f"grad_svc_{num}.pt")
    print(f'Loading checkpoint {model_path}...')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # The callable receives (storage, location); returning the storage
    # unchanged keeps the deserialized tensors on the CPU.
    model_dict = torch.load(
        model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(model_dict, strict=False)
    return model


def save_figure_to_numpy(fig):
    """Convert an already drawn figure to an RGB image array.

    ``np.fromstring`` (binary mode) and ``canvas.tostring_rgb`` are
    deprecated/removed in recent NumPy and Matplotlib releases, so the pixels
    are read through ``buffer_rgba`` and the alpha channel is dropped.

    Args:
        fig: Matplotlib figure whose canvas has been drawn and provides
            ``buffer_rgba`` (Agg-based canvases do).

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.
    """
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Copy so the result does not alias the canvas' internal buffer.
    return rgba[..., :3].copy()


def _draw_tensor_figure(tensor):
    """Render ``tensor`` as a 12x3 inch image with a colorbar; return the figure."""
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    return fig


def plot_tensor(tensor):
    """Plot a 2-D array-like as an image and return it as an RGB array.

    Args:
        tensor: Data accepted by ``Axes.imshow``.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.
    """
    fig = _draw_tensor_figure(tensor)
    try:
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure even if the conversion fails.
        plt.close(fig)


def save_plot(tensor, savepath):
    """Plot a 2-D array-like as an image and write it to ``savepath``.

    Args:
        tensor: Data accepted by ``Axes.imshow``.
        savepath: Destination passed to ``Figure.savefig``.
    """
    fig = _draw_tensor_figure(tensor)
    try:
        fig.savefig(savepath)
    finally:
        plt.close(fig)


def print_error(info):
    """Print ``info`` wrapped in ANSI escape codes for red text."""
    print(f"\033[31m {info} \033[0m")