import torch
import warnings

def get_fast_schedule(original_timesteps, fast_after_steps, fast_rate):
    """Keep the first `fast_after_steps` timesteps unchanged, then subsample
    the remaining timesteps at every `fast_rate`-th step."""
    if fast_after_steps >= len(original_timesteps) - 1:
        return original_timesteps
    new_timesteps = torch.cat(
        (original_timesteps[:fast_after_steps], original_timesteps[fast_after_steps + 1 :: fast_rate]), dim=0
    )
    return new_timesteps
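
# Example (illustrative values, not from the original file): for a 10-step
# schedule over 1000 training timesteps, tensor([999, 899, ..., 99]), calling
# get_fast_schedule(timesteps, fast_after_steps=5, fast_rate=2) keeps the first
# five entries and then takes every second remaining one, yielding
# tensor([999, 899, 799, 699, 599, 399, 199]). Note that the entry at index
# `fast_after_steps` itself (499 here) is skipped, because the slice starts at
# `fast_after_steps + 1`.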

def dynamically_adjust_inference_steps(scheduler, index, t):
    """Reset `scheduler.num_inference_steps` at each step so that the scheduler's
    internal step size (`num_train_timesteps // num_inference_steps`) matches the
    actual gap between the current timestep `t` and the next one."""
    prev_t = scheduler.timesteps[index + 1] if index + 1 < len(scheduler.timesteps) else -1
    scheduler.num_inference_steps = scheduler.config.num_train_timesteps // (t - prev_t)
    if index + 1 < len(scheduler.timesteps):
        if scheduler.config.num_train_timesteps // scheduler.num_inference_steps != t - prev_t:
            warnings.warn(
                f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) != ({t} - {prev_t}), "
                "so the step sizes may not be accurate"
            )
    else:
        # The final step may overshoot; as long as it lands on the final cumulative
        # alpha product, the result should be fine.
        if scheduler.config.num_train_timesteps // scheduler.num_inference_steps > t - prev_t:
            warnings.warn(
                f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) > ({t} - {prev_t}), "
                "so the step sizes may not be accurate"
            )
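

if __name__ == "__main__":
    # Minimal self-contained demo (illustrative assumptions, not part of the
    # original pipeline): a stand-in scheduler exposing only the attributes the
    # helpers above touch. Real use would target a diffusers-style scheduler
    # such as DDIMScheduler, whose `step()` derives the previous timestep from
    # `num_train_timesteps // num_inference_steps`.
    class _DemoConfig:
        num_train_timesteps = 1000

    class _DemoScheduler:
        config = _DemoConfig()
        num_inference_steps = 10
        timesteps = torch.arange(999, -1, -100)  # tensor([999, 899, ..., 99])

    scheduler = _DemoScheduler()
    scheduler.timesteps = get_fast_schedule(scheduler.timesteps, fast_after_steps=5, fast_rate=2)
    print(scheduler.timesteps)  # tensor([999, 899, 799, 699, 599, 399, 199])
    for i, t in enumerate(scheduler.timesteps):
        dynamically_adjust_inference_steps(scheduler, i, t)
        # Print the per-step value of num_inference_steps implied by the gap to
        # the next timestep (10 while the gap is 100, then 5 once it is 200).
        print(int(t), int(scheduler.num_inference_steps))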