s12 / utils /find_LR.py
srikanthp07's picture
Upload 27 files
9022436
raw
history blame contribute delete
No virus
397 Bytes
from torch_lr_finder import LRFinder
def find_lr(
    model,
    optimizer,
    criterion,
    device,
    train_loader,
    *,
    end_lr=10,
    num_iter=200,
    step_mode="exp",
    skip_start=0,
    skip_end=0,
):
    """Run a learning-rate range test and plot loss versus learning rate.

    Wraps ``torch_lr_finder.LRFinder``:
    1. Run ``range_test`` over ``train_loader``.
    2. Plot the result with an LR suggestion.
    3. Call ``reset()`` so the model and optimizer are handed back in their
       pre-test state. The reset runs even if the test or the plot raises.

    Args:
        model: The network to probe. It is passed straight to ``LRFinder``.
        optimizer: The optimizer whose learning rate is swept.
        criterion: The loss function used during the sweep.
        device: Passed as ``LRFinder(..., device=device)``.
        train_loader: The data iterable passed to ``range_test``.
        end_lr: The upper learning rate of the sweep. Default ``10``.
        num_iter: The number of iterations in the sweep. Default ``200``.
        step_mode: How the LR grows between iterations. Default ``"exp"``.
        skip_start: Points trimmed from the start of the plot. Default ``0``.
        skip_end: Points trimmed from the end of the plot. Default ``0``.

    Returns:
        Whatever ``lr_finder.plot(suggest_lr=True, ...)`` returns.
        NOTE(review): in torch-lr-finder this is ``(ax, suggested_lr)`` when a
        suggestion could be computed and just ``ax`` otherwise. It is not a
        bare float, despite the local name ``mx_lr``. Confirm against the
        installed version and how callers unpack it.
    """
    lr_finder = LRFinder(model, optimizer, criterion, device=device)
    try:
        lr_finder.range_test(
            train_loader,
            step_mode=step_mode,
            end_lr=end_lr,
            num_iter=num_iter,
        )
        mx_lr = lr_finder.plot(
            suggest_lr=True, skip_start=skip_start, skip_end=skip_end
        )
    finally:
        # range_test trains the model and mutates the optimizer. Always undo
        # that, even on failure, so callers never get a half-trained model.
        lr_finder.reset()
    return mx_lr