Spaces:
Paused
Paused
import os | |
def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10,
                         decay_rate=0.1, min_lr=1e-6):
    """Apply a step-decayed learning rate to every param group of *optimizer*.

    The rate is ``initial_lr * decay_rate ** (epoch // decay_epoch)``: it is
    multiplied by ``decay_rate`` once every ``decay_epoch`` epochs and then
    clamped from below at ``min_lr``. Note the clamp is a plain ``max``, so an
    ``initial_lr`` smaller than ``min_lr`` is raised to ``min_lr``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            (e.g. a ``torch.optim.Optimizer``); each dict's ``'lr'`` is set.
        epoch: Current epoch index; integer-divided by ``decay_epoch``.
        initial_lr: Learning rate before any decay step.
        decay_epoch: Number of epochs between decay steps (must be non-zero).
        decay_rate: Multiplicative factor applied at each decay step.
        min_lr: Lower bound on the returned learning rate.

    Returns:
        The learning rate that was written into every param group.
    """
    lr = max(initial_lr * (decay_rate ** (epoch // decay_epoch)), min_lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
class Save_Handle(object):
    """Keep at most ``max_num`` saved files, deleting the oldest on overflow.

    Paths are tracked in insertion order in ``save_list``. Once the list holds
    ``max_num`` entries, each further ``append`` evicts the oldest entry and
    deletes that file from disk, unless the same path is still tracked.
    """

    def __init__(self, max_num):
        # Tracked paths in the order they were appended, oldest first.
        self.save_list = []
        # Number of paths kept before the oldest one is evicted.
        self.max_num = max_num

    def append(self, save_path):
        """Record *save_path*; when full, evict and delete the oldest file.

        If the evicted path is still present in ``save_list`` (the same path
        was appended more than once, e.g. an overwritten ``best.pth``), the
        file is kept on disk, because it now holds a newer save that is still
        being tracked.

        Raises:
            IndexError: if ``max_num`` is 0 or less (nothing to evict).
        """
        if len(self.save_list) < self.max_num:
            self.save_list.append(save_path)
            return
        remove_path = self.save_list.pop(0)
        self.save_list.append(save_path)
        # Only delete the file once no remaining entry refers to it.
        if remove_path not in self.save_list and os.path.exists(remove_path):
            os.remove(remove_path)
class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all accumulated statistics to zero."""
        self.val = 0  # most recent value passed to update()
        self.avg = 0  # sum / count, or 0 while count is 0
        self.sum = 0  # weighted sum of all values: sum of val * n
        self.count = 0  # total weight: sum of all n

    def update(self, val, n=1):
        """Record *val* with weight *n* and refresh the running average.

        Args:
            val: The new value; it contributes ``val * n`` to ``sum``.
            n: Weight of this value (e.g. number of samples it summarizes).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates (n=0) have
        # been recorded; the average stays 0 until some weight accumulates.
        self.avg = self.sum / self.count if self.count else 0

    def get_avg(self):
        """Return the current weighted running average."""
        return self.avg

    def get_count(self):
        """Return the total accumulated weight."""
        return self.count
def set_trainable(model, requires_grad):
    """Set ``requires_grad`` to the given flag on every parameter of *model*.

    Args:
        model: Object whose ``parameters()`` yields the tensors to update
            (e.g. a ``torch.nn.Module``).
        requires_grad: Value assigned to each parameter's ``requires_grad``.
    """
    flag = requires_grad
    for tensor in model.parameters():
        setattr(tensor, "requires_grad", flag)
def get_num_params(model):
    """Count the elements in all parameters of *model* that require gradients.

    Args:
        model: Object whose ``parameters()`` yields tensors exposing
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        Total ``numel()`` over trainable parameters; 0 if there are none.
    """
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total