| # src/token_prune_scheduler.py | |
| import math | |
class TokenPruneScheduler:
    """Compute the token-pruning rate in effect at a given training step.

    Supported schedules (selected via ``args.token_prune_schedule``):
    ``"fixed"``, ``"linear-increase"``, ``"linear-decay"``, ``"cosine"``.
    An optional warmup phase disables pruning entirely; an unrecognized
    schedule name yields a rate of 0.0 (best-effort, no exception).
    """

    def __init__(self, args):
        # args is expected to provide: token_prune_warmup_steps,
        # token_prune_schedule, token_prune_rate, token_prune_schedule_steps,
        # token_prune_min_rate, token_prune_max_rate.
        self.args = args

    def rate(self, global_step: int, max_steps: int) -> float:
        """Return the pruning rate for ``global_step``.

        Args:
            global_step: Current optimization step (0-based).
            max_steps: Total training steps; used as the schedule length
                when ``token_prune_schedule_steps`` is not positive.

        Returns:
            Pruning rate as a float; 0.0 during warmup and for unknown
            schedule names.
        """
        # Coerce a missing/None warmup to 0 once, so the progress
        # computation below never subtracts None. (The warmup guard
        # already tolerated falsy values, but the subtraction did not.)
        warmup = self.args.token_prune_warmup_steps or 0

        # Warmup phase: no pruning at all.
        if warmup and global_step < warmup:
            return 0.0

        schedule = self.args.token_prune_schedule
        if schedule == "fixed":
            return float(self.args.token_prune_rate)

        # Length of the ramp; fall back to the full run when unset.
        total = (
            self.args.token_prune_schedule_steps
            if self.args.token_prune_schedule_steps > 0
            else max_steps
        )
        # Fraction of the ramp completed after warmup, clamped to [0, 1].
        progress = min(1.0, max(0.0, (global_step - warmup) / max(1, total)))

        r_min = float(self.args.token_prune_min_rate)
        r_max = float(self.args.token_prune_max_rate)
        if schedule == "linear-increase":
            return r_min + (r_max - r_min) * progress
        if schedule == "linear-decay":
            return r_max - (r_max - r_min) * progress
        if schedule == "cosine":
            # Half-cosine ramp from r_min up to r_max.
            return r_min + (r_max - r_min) * 0.5 * (1 - math.cos(math.pi * progress))
        # Unknown schedule name: disable pruning rather than raise, to
        # preserve the original best-effort behavior.
        return 0.0