import math

from diffusers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

from DiT_VAE.diffusion.utils.logger import get_root_logger

def build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio):
    """Build the lr scheduler selected by `config.lr_schedule` for the given optimizer."""
    if not config.get('lr_schedule_args', None):
        config.lr_schedule_args = {}
    if config.get('lr_warmup_steps', None):
        config['num_warmup_steps'] = config.get('lr_warmup_steps')  # for compatibility with old version
    logger = get_root_logger()
    logger.info(
        f'Lr schedule: {config.lr_schedule}, ' + ",".join(
            [f"{key}:{value}" for key, value in config.lr_schedule_args.items()]) + '.')
    if config.lr_schedule == 'cosine':
        # Cosine decay with warmup from diffusers, decaying over the full training run.
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
            num_training_steps=(len(train_dataloader) * config.num_epochs),
        )
    elif config.lr_schedule == 'constant':
        # Constant lr after warmup, also from diffusers.
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
        )
    elif config.lr_schedule == 'cosine_decay_to_constant':
        assert lr_scale_ratio >= 1
        # Cosine decay that plateaus at base_lr / lr_scale_ratio instead of decaying to zero.
        lr_scheduler = get_cosine_decay_to_constant_with_warmup(
            optimizer=optimizer,
            **config.lr_schedule_args,
            final_lr=1 / lr_scale_ratio,
            num_training_steps=(len(train_dataloader) * config.num_epochs),
        )
    else:
        raise RuntimeError(f'Unrecognized lr schedule {config.lr_schedule}.')
    return lr_scheduler
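

# Usage sketch (assumption: the surrounding project passes an mmcv-style Config that
# supports both attribute and item access, which is what the function above relies on):
#
#   config = Config(dict(
#       lr_schedule='cosine_decay_to_constant',
#       lr_schedule_args=dict(num_warmup_steps=500),
#       num_epochs=10,
#   ))
#   lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio=2)
#
# `Config`, the optimizer, and the dataloader are placeholders here, not part of this module.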


def get_cosine_decay_to_constant_with_warmup(optimizer: Optimizer,
                                             num_warmup_steps: int,
                                             num_training_steps: int,
                                             final_lr: float = 0.0,
                                             num_decay: float = 0.667,
                                             num_cycles: float = 0.5,
                                             last_epoch: int = -1
                                             ):
| """ | |
| Create a schedule with a cosine annealing lr followed by a constant lr. | |
| Args: | |
| optimizer ([`~torch.optim.Optimizer`]): | |
| The optimizer for which to schedule the learning rate. | |
| num_warmup_steps (`int`): | |
| The number of steps for the warmup phase. | |
| num_training_steps (`int`): | |
| The number of total training steps. | |
| final_lr (`int`): | |
| The final constant lr after cosine decay. | |
| num_decay (`int`): | |
| The | |
| last_epoch (`int`, *optional*, defaults to -1): | |
| The index of the last epoch when resuming training. | |
| Return: | |
| `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. | |
| """ | |
    def lr_lambda(current_step):
        # Linear warmup from 0 up to the base lr.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # After the decay window, hold the lr constant at `final_lr` (a fraction of the base lr).
        num_decay_steps = int(num_training_steps * num_decay)
        if current_step > num_decay_steps:
            return final_lr
        # Cosine decay from the base lr down to `final_lr` over the decay window.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_decay_steps - num_warmup_steps))
        return (
            max(
                0.0,
                0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)),
            )
            * (1 - final_lr)
        ) + final_lr

    return LambdaLR(optimizer, lr_lambda, last_epoch)
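

# Minimal standalone sketch (not part of the original module): exercise the
# decay-to-constant schedule with a dummy optimizer to inspect the lr curve.
# Only PyTorch is assumed; the parameter values below are illustrative.
if __name__ == "__main__":
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    scheduler = get_cosine_decay_to_constant_with_warmup(
        optimizer,
        num_warmup_steps=100,
        num_training_steps=1000,
        final_lr=0.5,  # plateau at 50% of the base lr
    )
    for step in range(1000):
        optimizer.step()
        scheduler.step()
        if step % 100 == 0:
            # lr warms up to 1e-4, cosine-decays until step ~667, then holds at 5e-5
            print(step, scheduler.get_last_lr()[0])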