code_SAS_VLM2Vec / src /token_prune_scheduler.py
MgGladys's picture
Add files using upload-large-folder tool
0a937d7 verified
# src/token_prune_scheduler.py
import math
class TokenPruneScheduler:
    """Compute a per-step token-prune rate from a configuration object.

    The scheduler reads these attributes from ``args``:
    ``token_prune_warmup_steps``, ``token_prune_schedule``,
    ``token_prune_rate``, ``token_prune_schedule_steps``,
    ``token_prune_min_rate`` and ``token_prune_max_rate``.
    """

    def __init__(self, args):
        """Store ``args``; its attributes are read lazily on each ``rate`` call."""
        self.args = args

    def rate(self, global_step: int, max_steps: int) -> float:
        """Return the prune rate to use at ``global_step``.

        Args:
            global_step: Current training step.
            max_steps: Used as the schedule length when
                ``token_prune_schedule_steps`` is unset or not positive.

        Returns:
            ``0.0`` while ``global_step`` is inside the warmup window;
            ``token_prune_rate`` for the ``"fixed"`` schedule; otherwise a
            value interpolated between ``token_prune_min_rate`` and
            ``token_prune_max_rate`` for ``"linear-increase"``,
            ``"linear-decay"`` and ``"cosine"``. Any other schedule name
            yields ``0.0``.
        """
        args = self.args

        # The warmup check below already tolerates a falsy (None / 0) value,
        # so normalise it once; otherwise the subtraction in the progress
        # computation raises TypeError when warmup is None.
        warmup_steps = args.token_prune_warmup_steps or 0

        # Warmup: no pruning.
        if warmup_steps and global_step < warmup_steps:
            return 0.0

        schedule = args.token_prune_schedule
        if schedule == "fixed":
            return float(args.token_prune_rate)

        # A missing or non-positive schedule length falls back to max_steps.
        schedule_steps = args.token_prune_schedule_steps or 0
        total = schedule_steps if schedule_steps > 0 else max_steps

        # Fraction of the schedule completed since warmup ended, clamped to [0, 1].
        # max(1, total) guards against division by zero.
        progress = min(1.0, max(0.0, (global_step - warmup_steps) / max(1, total)))

        r_min = float(args.token_prune_min_rate)
        r_max = float(args.token_prune_max_rate)

        if schedule == "linear-increase":
            return r_min + (r_max - r_min) * progress
        if schedule == "linear-decay":
            return r_max - (r_max - r_min) * progress
        if schedule == "cosine":
            # Half-cosine ramp from r_min (progress=0) to r_max (progress=1).
            return r_min + (r_max - r_min) * 0.5 * (1 - math.cos(math.pi * progress))

        # Unknown schedule name: pruning disabled.
        return 0.0