# workshop / USDRL / tools.py
# Uploaded by qiushuocheng ("Upload 117 files", commit 5de1792).
import numpy as np
import os, sys, shutil
import pickle
import yaml, torch
from datetime import datetime
from easydict import EasyDict as edict
from typing import Any, IO
def sum_para_cnt(model, *, trainable_only=False):
    """Return the total number of scalar elements across ``model``'s parameters.

    Args:
        model: Object exposing ``parameters()`` whose items provide
            ``nelement()`` (e.g. a ``torch.nn.Module``).
        trainable_only: If True, count only parameters whose
            ``requires_grad`` flag is set. Defaults to False, which counts
            every parameter (the original behavior).

    Returns:
        int: Sum of ``param.nelement()`` over the selected parameters
        (0 for a model with no parameters).
    """
    # Generator expression: no intermediate list is materialized.
    return sum(
        param.nelement()
        for param in model.parameters()
        if not trainable_only or param.requires_grad
    )
class AverageMeter(object):
    """Computes and stores the average and current value.

    Args:
        name: Label used when rendering the meter as text.
        fmt: Format spec (including the leading ``:``) applied to ``val``
            and ``avg`` in ``__str__``, e.g. ``':.3f'``. Defaults to ``':f'``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # weighted running mean: sum / count
        self.sum = 0    # running sum of val * n
        self.count = 0  # running total of the weights n

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Args:
            val: The newest value; also stored as ``self.val``.
            n: Weight applied to ``val`` in the running sum and added to
                ``self.count``. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) before any other update)
        # previously raised ZeroDivisionError; keep the mean at 0 instead,
        # matching the freshly reset state.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'`` using ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def get_str(self):
        """Return ``'<name>: <avg>'`` with avg to 4 decimals, plus a trailing tab."""
        formatted_num = "{:.4f}".format(self.avg)
        return self.name + ': ' + str(formatted_num) + '\t'
def remove_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading key prefix stripped.

    The default ``'module.'`` is the prefix PyTorch's ``DataParallel`` /
    ``DistributedDataParallel`` wrappers put on parameter names. Removing it
    lets a checkpoint saved from a wrapped model load into an unwrapped one.

    Only a *leading* prefix is removed. Repeated leading prefixes (e.g.
    ``'module.module.fc.weight'``) are all stripped. Keys that merely contain
    the substring elsewhere (e.g. ``'encoder.submodule.weight'``) are left
    intact. The previous ``split('module.')[-1]`` truncated such keys and
    could silently collide them.

    Args:
        state_dict: Mapping of parameter names to values.
        prefix: Leading string to strip from each key. An empty prefix
            leaves keys unchanged.

    Returns:
        dict: New dict with stripped keys, in the original iteration order.
        If two keys strip to the same name, the later one wins.
    """
    prefix_len = len(prefix)
    new_state_dict = {}
    for key, value in state_dict.items():
        # Guard on a non-empty prefix: startswith('') is always True and
        # would otherwise loop forever.
        if prefix:
            while key.startswith(prefix):
                key = key[prefix_len:]
        new_state_dict[key] = value
    return new_state_dict