from typing import Callable, Optional import numpy as np # Whatever this is, it's utterly cursed. def ordered_halving(val): bin_str = f"{val:064b}" bin_flip = bin_str[::-1] as_int = int(bin_flip, 2) return as_int / (1 << 64) # I have absolutely no idea how this works and I don't like that. def uniform( step: int = ..., num_steps: Optional[int] = None, num_frames: int = ..., context_size: Optional[int] = None, context_stride: int = 3, context_overlap: int = 4, closed_loop: bool = True, ): if num_frames <= context_size: yield list(range(num_frames)) return context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) for context_step in 1 << np.arange(context_stride): pad = int(round(num_frames * ordered_halving(step))) for j in range( int(ordered_halving(step) * context_step) + pad, num_frames + pad + (0 if closed_loop else -context_overlap), (context_size * context_step - context_overlap), ): yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] def shuffle( step: int = ..., num_steps: Optional[int] = None, num_frames: int = ..., context_size: Optional[int] = None, context_stride: int = 3, context_overlap: int = 4, closed_loop: bool = True, ): import random c = list(range(num_frames)) c = random.sample(c, len(c)) if len(c) % context_size: c += c[0:context_size - len(c) % context_size] c = random.sample(c, len(c)) for i in range(0, len(c), context_size): yield c[i:i+context_size] def composite( step: int = ..., num_steps: Optional[int] = None, num_frames: int = ..., context_size: Optional[int] = None, context_stride: int = 3, context_overlap: int = 4, closed_loop: bool = True, ): if (step/num_steps) < 0.1: return shuffle(step,num_steps,num_frames,context_size,context_stride,context_overlap,closed_loop) else: return uniform(step,num_steps,num_frames,context_size,context_stride,context_overlap,closed_loop) def get_context_scheduler(name: str) -> Callable: match name: case "uniform": return uniform case "shuffle": return shuffle case "composite": return composite case _: raise ValueError(f"Unknown context_overlap policy {name}") def get_total_steps( scheduler, timesteps: list[int], num_steps: Optional[int] = None, num_frames: int = ..., context_size: Optional[int] = None, context_stride: int = 3, context_overlap: int = 4, closed_loop: bool = True, ): return sum( len( list( scheduler( i, num_steps, num_frames, context_size, context_stride, context_overlap, ) ) ) for i in range(len(timesteps)) )