import numpy as np import os, sys, shutil import pickle import yaml, torch from datetime import datetime from easydict import EasyDict as edict from typing import Any, IO def sum_para_cnt(model): return sum([param.nelement() for param in model.parameters()]) class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self, name, fmt=':f'): self.name = name self.fmt = fmt self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def __str__(self): fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' return fmtstr.format(**self.__dict__) def get_str(self): formatted_num = "{:.4f}".format(self.avg) return self.name+': ' + str(formatted_num) + '\t' def remove_prefix(state_dict): new_state_dict = {} for k, v in state_dict.items(): k = k.split('module.')[-1] new_state_dict[k] = v return new_state_dict