# TreeFormer/utils/pytorch_utils.py
# Uploaded by franciszzj in commit c964d4c ("init").
import os
def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10,
                         gamma=0.1, min_lr=1e-6):
    """Apply a step-decay learning-rate schedule to every param group.

    The learning rate is ``initial_lr`` multiplied by ``gamma`` once for every
    ``decay_epoch`` epochs, and never drops below ``min_lr``::

        lr = max(initial_lr * gamma ** (epoch // decay_epoch), min_lr)

    With the defaults this divides the rate by 10 every 10 epochs, with a
    floor of 1e-6.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key), e.g. a ``torch.optim.Optimizer``.
        epoch: current epoch index; epoch 0 uses ``initial_lr``.
        initial_lr: learning rate used for the first ``decay_epoch`` epochs.
        decay_epoch: number of epochs between successive decays.
        gamma: multiplicative decay factor applied at each step.
        min_lr: lower bound on the learning rate.

    Returns:
        The learning rate written into every param group.
    """
    lr = max(initial_lr * (gamma ** (epoch // decay_epoch)), min_lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
class Save_Handle(object):
    """Keep only the ``max_num`` most recently saved files (e.g. checkpoints).

    Each call to :meth:`append` records a newly written path. Once
    ``max_num`` paths are already tracked, the oldest one is dropped from the
    list and deleted from disk (if it still exists). ``max_num`` is expected
    to be at least 1.
    """

    def __init__(self, max_num):
        # Tracked paths in insertion order; index 0 is the oldest.
        self.save_list = []
        self.max_num = max_num

    def append(self, save_path):
        """Record ``save_path``, evicting and deleting the oldest path if full.

        If the evicted path is still tracked (the same filename was appended
        again later), the file on disk belongs to a newer save and is kept.
        """
        if len(self.save_list) < self.max_num:
            self.save_list.append(save_path)
            return
        remove_path = self.save_list.pop(0)
        self.save_list.append(save_path)
        # Do not delete a file that a newer entry still refers to.
        if remove_path not in self.save_list and os.path.exists(remove_path):
            os.remove(remove_path)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = 1.0 * self.sum / self.count
def get_avg(self):
return self.avg
def get_count(self):
return self.count
def set_trainable(model, requires_grad):
    """Set ``requires_grad`` on every parameter yielded by ``model.parameters()``.

    Pass ``False`` to freeze the model's weights or ``True`` to unfreeze them.
    """
    params = model.parameters()
    for weight in params:
        weight.requires_grad = requires_grad
def get_num_params(model):
    """Return the total number of elements in the model's trainable parameters.

    Parameters whose ``requires_grad`` is false are not counted.
    """
    total = 0
    for weight in model.parameters():
        if not weight.requires_grad:
            continue
        total += weight.numel()
    return total