|
|
|
from typing import Callable, List, Optional |
|
|
|
import numpy as np |
|
|
|
|
|
def ordered_halving(val):
    """Map ``val`` into [0, 1) by mirroring its bits around the binary point.

    ``val`` is formatted as a zero-padded 64-digit binary string, the string
    is reversed, and the resulting integer is divided by ``2**64``. For
    0, 1, 2, 3, ... this gives 0.0, 0.5, 0.25, 0.75, ... (the base-2 van der
    Corput sequence), so consecutive inputs are spread evenly over [0, 1).

    Only non-negative values below ``2**64`` are guaranteed to land in
    [0, 1). Larger values format to more than 64 digits, so the result can
    exceed 1. Negative values make ``int()`` raise ``ValueError`` because the
    reversed string ends with the sign character.
    """
    mirrored_bits = format(val, "064b")[::-1]
    return int(mirrored_bits, 2) / 2**64
|
|
|
|
|
def uniform(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield overlapping windows of frame indices for the given ``step``.

    If all ``num_frames`` frames fit in one window, a single window
    ``list(range(num_frames))`` is yielded. Otherwise, windows of
    ``context_size`` indices are produced at dilations
    ``context_step = 1, 2, 4, ...``. Each window takes every
    ``context_step``-th index. Consecutive windows at one dilation start
    ``context_size * context_step - context_overlap`` positions apart.
    Window start positions are shifted by an offset derived from
    ``ordered_halving(step)``, so different ``step`` values produce
    differently aligned windows. Indices wrap modulo ``num_frames``.

    Args:
        step: Index used to derive the window offset via ``ordered_halving``.
        num_steps: Not used by this scheduler. Presumably it is kept so that
            all schedulers share one call signature (TODO confirm).
        num_frames: Total number of frames to cover.
        context_size: Number of indices per window. Must be a positive int.
        context_stride: Maximum number of dilation levels. It is further
            capped at ``ceil(log2(num_frames / context_size)) + 1``.
        context_overlap: How many positions consecutive windows share at
            dilation 1. Must be smaller than ``context_size``.
        closed_loop: If False, the exclusive upper bound on window start
            positions is lowered by ``context_overlap``, so fewer windows
            start near the end and wrap around to the first frames.

    Yields:
        Lists of ``int`` frame indices in ``[0, num_frames)``.

    Raises:
        ValueError: If ``context_size`` is None or < 1. Also raised when
            windowing is needed and ``context_overlap >= context_size``.
            Because this is a generator, the error surfaces on the first
            iteration, not at call time.
    """
    if context_size is None or context_size < 1:
        raise ValueError(
            f"context_size must be a positive integer, got {context_size!r}"
        )

    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    # At dilation 1 the start advances by context_size - context_overlap.
    # A zero advance makes range() raise a cryptic error. A negative one
    # silently produces no windows at that level, leaving frames uncovered.
    if context_overlap >= context_size:
        raise ValueError(
            f"context_overlap ({context_overlap}) must be smaller than "
            f"context_size ({context_size})"
        )

    # Stop at the first dilation whose span (context_size * context_step)
    # reaches num_frames. Further levels would only span more frames than exist.
    context_stride = min(
        context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
    )

    # The offset depends only on `step`, so compute it once, not per level.
    offset = ordered_halving(step)
    pad = int(round(num_frames * offset))
    stop = num_frames + pad + (0 if closed_loop else -context_overlap)

    for context_step in 1 << np.arange(context_stride):
        window_span = context_size * context_step
        for start in range(
            int(offset * context_step) + pad,
            stop,
            window_span - context_overlap,
        ):
            yield [
                e % num_frames
                for e in range(start, start + window_span, context_step)
            ]
|
|
|
|
|
def get_context_scheduler(name: str) -> Callable:
    """Return the context-scheduler generator function registered as ``name``.

    Args:
        name: Scheduler name. Only ``"uniform"`` is available.

    Returns:
        The ``uniform`` scheduler function.

    Raises:
        ValueError: If ``name`` is not a known scheduler name.
    """
    if name != "uniform":
        # The old message called this a "context_overlap policy", which was
        # misleading: ``name`` selects a scheduler, not an overlap setting.
        raise ValueError(f"Unknown context scheduler {name!r}; expected 'uniform'")
    return uniform
|
|
|
|
|
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` yields across all timesteps.

    ``scheduler`` is called once per position in ``timesteps``. It receives
    that position's index as its ``step`` argument; the timestep values
    themselves are not used. The windows each call yields are counted and
    the counts are summed.

    Args:
        scheduler: Callable with the same parameters as ``uniform``
            (``step, num_steps, num_frames, context_size, context_stride,
            context_overlap, closed_loop``) returning an iterable of windows.
        timesteps: Only its length is used, to determine how many steps run.
        num_steps, num_frames, context_size, context_stride, context_overlap:
            Forwarded unchanged to ``scheduler``.
        closed_loop: Forwarded to ``scheduler``. It was previously dropped,
            so the count always assumed ``closed_loop=True`` and could
            disagree with the windows actually produced.

    Returns:
        Total number of windows yielded over all steps.
    """
    return sum(
        # Count lazily instead of materializing each list just to take len().
        sum(
            1
            for _ in scheduler(
                step,
                num_steps,
                num_frames,
                context_size,
                context_stride,
                context_overlap,
                closed_loop=closed_loop,
            )
        )
        for step in range(len(timesteps))
    )
|
|