Spaces:
Sleeping
Sleeping
import torch | |
import warnings | |
def get_fast_schedule(origial_timesteps, fast_after_steps, fast_rate):
    """Build a sparser timestep schedule that switches to strided sampling.

    Keeps the first ``fast_after_steps`` entries of ``origial_timesteps``
    as-is, then appends every ``fast_rate``-th entry starting at index
    ``fast_after_steps + 1``. The pieces are joined with ``torch.cat``
    along dim 0.

    If ``fast_after_steps >= len(origial_timesteps) - 1``, there is nothing
    to thin out, and the input object itself is returned unchanged.

    NOTE(review): the entry at index ``fast_after_steps`` is always skipped,
    so ``fast_rate == 1`` still drops one element. The jump from the last
    kept head entry to the first tail entry is 2 indices, which matches the
    tail stride only when ``fast_rate == 2``. Confirm this is intended.

    Args:
        origial_timesteps: 1-D sequence of timesteps (sliced and passed to
            ``torch.cat``, so presumably a ``torch.Tensor``).
        fast_after_steps: number of leading timesteps kept unmodified.
        fast_rate: slicing stride applied after the kept prefix.

    Returns:
        The thinned schedule, or ``origial_timesteps`` itself when no
        thinning applies.
    """
    total = len(origial_timesteps)
    if fast_after_steps < total - 1:
        head = origial_timesteps[:fast_after_steps]
        tail = origial_timesteps[fast_after_steps + 1::fast_rate]
        return torch.cat((head, tail), dim=0)
    return origial_timesteps
def dynamically_adjust_inference_steps(scheduler, index, t):
    """Retune ``scheduler.num_inference_steps`` for the step starting at ``t``.

    The gap to the next timestep is ``t - prev_t``. Here ``prev_t`` is
    ``scheduler.timesteps[index + 1]``, or ``-1`` when ``t`` is the last
    scheduled timestep. The step count is set to
    ``num_train_timesteps // gap``, presumably so that a scheduler deriving
    its stride as ``num_train_timesteps // num_inference_steps`` lands on
    ``prev_t`` (TODO confirm against the scheduler in use).

    Integer division may not reproduce the gap exactly, so a warning is
    emitted when it doesn't:
      * intermediate steps warn when the derived stride differs from the gap;
      * the final step warns only when the derived stride overshoots it.

    Args:
        scheduler: object exposing ``timesteps``, ``config.num_train_timesteps``
            and a writable ``num_inference_steps``; it is mutated in place.
        index: position of ``t`` within ``scheduler.timesteps``.
        t: the current timestep. It may be a tensor element, in which case
            the stored step count is a tensor as well.
    """
    next_index = index + 1
    has_next = next_index < len(scheduler.timesteps)
    prev_t = scheduler.timesteps[next_index] if has_next else -1

    train_steps = scheduler.config.num_train_timesteps
    gap = t - prev_t
    scheduler.num_inference_steps = train_steps // gap
    # Read the attribute back, exactly as the original does.
    implied_gap = train_steps // scheduler.num_inference_steps

    if has_next:
        if implied_gap != gap:
            warnings.warn(f"({train_steps} // {scheduler.num_inference_steps}) != ({t} - {prev_t}), so the step sizes may not be accurate")
    # Final step: as long as the final cumulative alpha product is reached,
    # a stride that does not overshoot the gap is acceptable.
    elif implied_gap > gap:
        warnings.warn(f"({train_steps} // {scheduler.num_inference_steps}) > ({t} - {prev_t}), so the step sizes may not be accurate")