Spaces:
Sleeping
Sleeping
| import tqdm | |
class LearnScheduleIterator:
    """Iterator over the ``(rate, end_step)`` segments of a learning-rate schedule.

    Each yielded tuple means "use ``rate`` until training reaches ``end_step``".
    Every ``end_step`` is clamped to ``max_steps``.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """
        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000

        Args:
            learn_rate: Comma-separated ``rate:step`` entries. A bare ``rate``
                (or a step of ``-1``) means "until ``max_steps``" and ends the
                schedule. A plain int/float is accepted and is treated like its
                string form.
            max_steps: Total number of training steps. Every end step is
                clamped to it.
            cur_step: Step that training (re)starts from. Entries whose step
                is at or before it are skipped.

        Raises:
            Exception: If the schedule is malformed, or if no segment ends
                after ``cur_step``.
        """
        self.rates = []  # parsed (rate, end_step) tuples, in schedule order
        self.it = 0  # index of the next segment to hand out
        self.maxit = 0  # number of parsed segments
        try:
            self._parse(str(learn_rate), max_steps, cur_step)
            # Explicit check rather than `assert`, which `python -O` strips out.
            if not self.rates:
                raise ValueError('no schedule segment ends after the current step')
        except ValueError as e:
            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
        self.maxit = len(self.rates)

    def _parse(self, learn_rate, max_steps, cur_step):
        """Append ``(rate, end_step)`` tuples to ``self.rates``; raise ValueError on bad entries."""
        for pair in learn_rate.split(','):
            if not pair.strip():
                continue  # tolerate empty entries, e.g. a trailing comma
            parts = pair.split(':')
            if len(parts) == 1:
                # A bare rate applies for the rest of training; later entries are ignored.
                self.rates.append((float(parts[0]), max_steps))
                return
            if len(parts) != 2:
                raise ValueError(f'malformed schedule entry {pair!r}')
            rate, step = float(parts[0]), int(parts[1])
            if step == -1:
                # -1 means "until max_steps"; later entries are ignored.
                self.rates.append((rate, max_steps))
                return
            if step <= cur_step:
                continue  # segment already completed (e.g. when resuming training)
            self.rates.append((rate, min(step, max_steps)))
            if step >= max_steps:
                # This segment already runs to max_steps, so later entries can never apply.
                return

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(rate, end_step)`` tuple, or raise StopIteration when exhausted."""
        if self.it >= self.maxit:
            raise StopIteration
        self.it += 1
        return self.rates[self.it - 1]
class LearnRateScheduler:
    """Drive an optimizer's learning rate from a piecewise schedule string.

    ``learn_rate`` and ``end_step`` hold the active segment. ``finished`` becomes
    True once a step reaches the last segment's end and no further segment exists.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        self.finished = False
        self.verbose = verbose
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        # Load the first segment right away so learn_rate/end_step are always populated.
        first_rate, first_end = next(self.schedules)
        self.learn_rate = first_rate
        self.end_step = first_end
        if verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

    def step(self, step_number):
        """Advance to the next segment once ``step_number`` reaches ``end_step``.

        Returns True only when a new segment was loaded. When the schedule is
        exhausted, sets ``finished`` and returns False.
        """
        if step_number < self.end_step:
            return False
        upcoming = next(self.schedules, None)
        if upcoming is None:
            self.finished = True
            return False
        self.learn_rate, self.end_step = upcoming
        return True

    def apply(self, optimizer, step_number):
        """If ``step`` switched segments, write the new rate into every optimizer param group."""
        if self.step(step_number):
            if self.verbose:
                # tqdm.write keeps the message from corrupting an active progress bar.
                tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
            for group in optimizer.param_groups:
                group['lr'] = self.learn_rate