|
""" |
|
Schedulers Script ver: Feb 15th 17:00 |
|
|
|
patch_scheduler arranges the puzzle patch size for multi-scale learning; ratio_scheduler arranges the fix-position ratio.
|
|
|
Reference:

adjust_learning_rate (the learning-rate scheduler) is adapted from the MAE code.
|
https://github.com/facebookresearch/mae |
|
""" |
|
|
|
import math |
|
import random |
|
|
|
|
|
def factor(num):
    """
    Return all positive divisors of ``num``.

    Divisors are collected in pairs ``(i, num // i)`` for increasing ``i`` up to
    the integer square root, so the result is NOT sorted (e.g. ``factor(12)`` is
    ``[1, 12, 2, 6, 3, 4]``). A square-root divisor is listed only once, and
    ``factor(0)`` returns ``[]``.

    Only integer arithmetic is used, so the result stays exact for large ``num``.
    Float ``math.sqrt`` and ``/`` lose precision beyond 2**53.

    :param num: non-negative integer to factor.
    :return: list of the divisors of ``num``.
    :raises ValueError: if ``num`` is negative.
    """
    if num < 0:
        raise ValueError('factor() requires a non-negative number, got %r' % (num,))

    factors = []
    i = 1
    # i * i <= num is the exact integer form of i <= sqrt(num)
    while i * i <= num:
        if num % i == 0:
            factors.append(i)
            partner = num // i
            if partner != i:
                factors.append(partner)
        i += 1
    return factors
|
|
|
|
|
def defactor(num_list, basic_num):
    """
    Return the elements of ``num_list`` that are exact multiples of ``basic_num``, sorted ascending.

    In this module it filters the divisors of the image edge down to valid patch
    sizes (multiples of the basic patch size). ``num_list`` itself is not modified.

    :param num_list: iterable of integers.
    :param basic_num: the required base; must be non-zero.
    :return: a new ascending list.
    :raises ZeroDivisionError: if ``basic_num`` is 0.
    """
    return sorted(num for num in num_list if num % basic_num == 0)
|
|
|
|
|
def adjust_learning_rate(optimizer, epoch, args):
    """
    Update every param group's learning rate: linear warmup, then half-cycle cosine decay.

    :param optimizer: object exposing ``param_groups``, an iterable of dicts. Each
        group's ``"lr"`` is overwritten, multiplied by the group's ``"lr_scale"``
        entry when that key is present.
    :param epoch: current epoch. Fractional values are fine, e.g.
        ``data_iter_step / len(data_loader) + epoch``, for finer granularity.
    :param args: object with ``lr``, ``min_lr``, ``warmup_epochs`` and ``epochs``.
    :return: the base learning rate, before any per-group ``lr_scale``.
    """
    warmup = args.warmup_epochs

    if epoch < warmup:
        # linear ramp from 0 up to args.lr
        base_lr = args.lr * epoch / warmup
    else:
        # half cosine from args.lr (end of warmup) down to args.min_lr (at args.epochs)
        angle = math.pi * (epoch - warmup) / (args.epochs - warmup)
        base_lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(angle))

    for group in optimizer.param_groups:
        group["lr"] = base_lr * group["lr_scale"] if "lr_scale" in group else base_lr

    return base_lr
|
|
|
|
|
class patch_scheduler:
    """
    Drive the puzzle patch size by epoch (and optionally loss) for multi-scale learning.

    The candidate patch sizes (``self.patch_list``) are computed automatically.
    They are every divisor of ``edge_size`` that is a multiple of ``basic_patch``,
    in ascending order. The largest one (``edge_size`` itself) is dropped when
    more than one candidate exists.

    ``patch_size_jump`` optionally thins that list:
        'odd': keep the 1st, 3rd, 5th, ... candidates.
        'even': keep the first candidate plus the 2nd, 4th, ... ones.
    For 'reverse', 'loss_back' and 'loss_hold' the list is then sorted in
    descending order.

    Strategies (``strategy``):
        'linear' / 'reverse': after warmup, walk through ``patch_list`` in
            proportion to post-warmup progress.
        'loop': after warmup, cycle through ``patch_list``, keeping each size for
            ``int(threshold)`` epochs (at least 1).
        'random': a random candidate on every call (no warmup phase).
        'loss_back': follow the descending schedule, with two adjustments:
            - a nonzero ``loss`` below ``threshold`` moves one step ahead and
              multiplies ``threshold`` by ``reducing_factor``;
            - any other nonzero ``loss`` moves one step back.
        'loss_hold': like 'loss_back', but stays on schedule instead of stepping back.
        None / 'fixed' / anything else: ``fix_patch_size``, or ``patch_list[0]`` if unset.

    During the first ``warmup_epochs`` epochs the epoch-driven strategies
    ('linear', 'reverse', 'loop', 'loss_back', 'loss_hold') return
    ``warmup_patch_size`` (default 32). NOTE: this value is not checked against
    ``patch_list``. Pass a valid size when ``edge_size`` is not a multiple of 32.
    """

    # strategies that return warmup_patch_size before warmup_epochs
    _WARMUP_STRATEGIES = ('linear', 'reverse', 'loop', 'loss_back', 'loss_hold')

    def __init__(self, total_epoches=200, warmup_epochs=20, edge_size=384, basic_patch=16, strategy=None,
                 threshold=3.0, reducing_factor=0.933, fix_patch_size=None, patch_size_jump=None,
                 warmup_patch_size=32):
        super().__init__()

        self.strategy = strategy
        self.total_epoches = total_epoches
        self.warmup_epochs = warmup_epochs
        # loss threshold for 'loss_back'/'loss_hold'; also the group size for 'loop'
        self.threshold = threshold
        self.reducing_factor = reducing_factor
        self.fix_patch_size = fix_patch_size
        self.warmup_patch_size = warmup_patch_size

        patch_list = defactor(factor(edge_size), basic_patch)
        # drop the largest candidate (edge_size itself) as long as something else remains
        if len(patch_list) > 1:
            patch_list = patch_list[:-1]

        if patch_size_jump == 'odd':
            patch_list = patch_list[0::2]
        elif patch_size_jump == 'even':
            # slice instead of patch_list[0] so an empty candidate list does not raise here
            patch_list = patch_list[:1] + patch_list[1::2]

        if strategy in ('reverse', 'loss_back', 'loss_hold'):
            patch_list.sort(reverse=True)
        self.patch_list = patch_list

        if strategy is None or strategy == 'fixed':
            print('patch_list:', self.fix_patch_size or self.patch_list[0])
        else:
            print('patch_list:', self.patch_list)

    def _scheduled_index(self, epoch, offset=0):
        """Return the patch_list index for ``epoch``'s post-warmup progress, shifted by ``offset`` and clamped."""
        last = len(self.patch_list) - 1
        progress = (epoch - self.warmup_epochs) / (self.total_epoches - self.warmup_epochs)
        index = int(progress * len(self.patch_list)) + offset
        return min(max(index, 0), last)

    def __call__(self, epoch, loss=0.0):
        """
        Return the puzzle patch size for ``epoch``.

        :param epoch: current epoch.
        :param loss: latest loss, used only by 'loss_back'/'loss_hold'. 0.0 means
            "no loss information" and follows the plain schedule.
        """
        strategy = self.strategy

        if strategy in self._WARMUP_STRATEGIES and epoch < self.warmup_epochs:
            return self.warmup_patch_size

        if strategy in ('linear', 'reverse'):
            return self.patch_list[self._scheduled_index(epoch)]

        if strategy == 'loop':
            # threshold < 1 would give a group size of 0 and a modulo-by-zero below
            group_size = max(1, int(self.threshold))
            group_idx = (epoch - self.warmup_epochs) % (len(self.patch_list) * group_size)
            return self.patch_list[int(group_idx / group_size)]

        if strategy == 'random':
            return random.choice(self.patch_list)

        if strategy in ('loss_back', 'loss_hold'):
            if loss == 0.0:
                offset = 0
            elif loss < self.threshold:
                # loss beat the threshold: move ahead one step and tighten the threshold
                offset = 1
                self.threshold *= self.reducing_factor
            else:
                offset = -1 if strategy == 'loss_back' else 0
            return self.patch_list[self._scheduled_index(epoch, offset)]

        return self.fix_patch_size or self.patch_list[0]
|
|
|
|
|
class ratio_scheduler:
    """
    Drive the fix-position ratio by epoch (and optionally loss).

    After warmup the ratio decays linearly from ``max_ratio`` (at the end of
    warmup) to ``min_ratio`` (at ``total_epoches``). The bounds are
        max_ratio = min(3 * basic_ratio, upper_limit)
        min_ratio = max(basic_ratio * ratio_floor_factor, lower_limit)
    and the result is always clamped to ``[min_ratio, max_ratio]``. If the bounds
    cross, ``max_ratio`` wins.

    Strategies (``strategy``):
        'decay' / 'ratio-decay': plain linear decay.
        'loss_back': the decayed ratio is scaled according to ``loss``:
            - by 0.9 when a nonzero ``loss`` is below ``threshold`` (the threshold
              is then multiplied by ``loss_reducing_factor``);
            - by 1.1 for any other nonzero ``loss``.
        'loss_hold': like 'loss_back', but unscaled when the loss is not below
            the threshold.
        None / anything else: ``fix_position_ratio``, or ``basic_ratio`` if it is
            unset (any falsy value).

    During the first ``warmup_epochs`` epochs the decaying strategies return
    ``basic_ratio``.
    """

    def __init__(self, total_epoches=200, warmup_epochs=20, basic_ratio=0.25, strategy=None, fix_position_ratio=None,
                 threshold=4.0, loss_reducing_factor=0.933, ratio_floor_factor=0.5, upper_limit=0.9, lower_limit=0.2):
        super().__init__()
        self.strategy = strategy

        self.total_epoches = total_epoches
        self.warmup_epochs = warmup_epochs

        self.basic_ratio = basic_ratio

        # loss threshold for the loss-driven strategies; shrinks each time the loss beats it
        self.threshold = threshold
        self.loss_reducing_factor = loss_reducing_factor

        self.fix_position_ratio = fix_position_ratio

        self.upper_limit = upper_limit
        self.lower_limit = lower_limit
        self.ratio_floor_factor = ratio_floor_factor

    def _decayed_ratio(self, epoch, scale=1.0):
        """Return the linearly decayed ratio for post-warmup ``epoch``, times ``scale``, clamped to the bounds."""
        max_ratio = min(3 * self.basic_ratio, self.upper_limit)
        min_ratio = max(self.basic_ratio * self.ratio_floor_factor, self.lower_limit)
        span = self.total_epoches - self.warmup_epochs
        # fraction of the post-warmup schedule still remaining: 1 at end of warmup, 0 at total_epoches
        remaining = (span - (epoch - self.warmup_epochs)) / span
        return min(max(remaining * max_ratio * scale, min_ratio), max_ratio)

    def _loss_scale(self, loss):
        """Return the loss-driven multiplier for the decayed ratio (may shrink ``self.threshold``)."""
        if loss == 0.0:
            # no loss information: follow the plain schedule
            return 1.0
        if loss < self.threshold:
            # loss beat the threshold: lower the ratio and tighten the threshold
            self.threshold *= self.loss_reducing_factor
            return 0.9
        return 1.1 if self.strategy == 'loss_back' else 1.0

    def __call__(self, epoch, loss=0.0):
        """
        Return the fix-position ratio for ``epoch``.

        :param epoch: current epoch.
        :param loss: latest loss, used only by 'loss_back'/'loss_hold'. 0.0 means
            "no loss information".
        """
        if self.strategy not in ('ratio-decay', 'decay', 'loss_back', 'loss_hold'):
            return self.fix_position_ratio or self.basic_ratio

        if epoch < self.warmup_epochs:
            return self.basic_ratio

        if self.strategy in ('loss_back', 'loss_hold'):
            scale = self._loss_scale(loss)
        else:
            scale = 1.0
        return self._decayed_ratio(epoch, scale)
|
|
|
|
|
''' |
|
scheduler = ratio_scheduler(strategy='decay')
|
epoch = 102 |
|
fix_position_ratio = scheduler(epoch) |
|
print(fix_position_ratio) |
|
''' |
|
|