from typing import Optional

import torch
from einops import rearrange


def isinstance_str(x: object, cls_name: str) -> bool:
    """
    Checks whether x has any class *named* cls_name in its ancestry.
    Doesn't require access to the class's implementation.

    Useful for patching!

    Args:
        x: Object whose class hierarchy is inspected.
        cls_name: Bare class name (compared against ``__name__``; the
            defining module is not considered).

    Returns:
        True if any class in ``x.__class__.__mro__`` has
        ``__name__ == cls_name``, otherwise False.
    """
    return any(_cls.__name__ == cls_name for _cls in x.__class__.__mro__)


def init_generator(device: torch.device, fallback: Optional[torch.Generator] = None):
    """
    Forks the current default random generator given device.

    A new ``torch.Generator`` is created and initialised with a copy of the
    default RNG state for ``device``.

    Args:
        device: Device whose default RNG should be forked.
        fallback: Generator to return for device types other than "cpu" and
            "cuda". If None, a fork of the CPU default generator is returned
            for such devices instead.

    Returns:
        A torch.Generator (or ``fallback`` as described above).
    """
    if device.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())
    if device.type == "cuda":
        # Pass `device` explicitly: with no argument, torch.cuda.get_rng_state()
        # returns the state of the *current* CUDA device, which is not
        # necessarily `device` on multi-GPU setups (e.g. cuda:1 requested while
        # cuda:0 is current).
        return torch.Generator(device=device).set_state(torch.cuda.get_rng_state(device))
    if fallback is None:
        # NOTE(review): this hands back a CPU generator for e.g. "mps" devices;
        # confirm callers sample on CPU in that case.
        return init_generator(torch.device("cpu"))
    return fallback


def join_frame(x, fsize: int):
    """
    Join multi-frame tokens.

    Rearranges ``(B*F, N, C) -> (B, F*N, C)`` with ``F = fsize``: the F
    per-frame token sequences of each batch item are concatenated along the
    token axis. Inverse of ``split_frame``.
    """
    return rearrange(x, "(B F) N C -> B (F N) C", F=fsize)


def split_frame(x, fsize: int):
    """
    Split multi-frame tokens.

    Rearranges ``(B, F*N, C) -> (B*F, N, C)`` with ``F = fsize``: each batch
    item's token axis is cut into F equal per-frame chunks, which are moved
    into the leading dimension. Inverse of ``join_frame``.
    """
    return rearrange(x, "B (F N) C -> (B F) N C", F=fsize)


def func_warper(funcs):
    """
    Wrap a function sequence into a single callable.

    The returned ``fn(x, **kwarg)`` applies each function in ``funcs`` in
    order, feeding each the previous result and forwarding the same keyword
    arguments to every call. (The name is presumably a misspelling of
    "wrapper"; it is kept as-is for existing callers.)

    NOTE: ``funcs`` is iterated on every call of the returned function, so it
    should be a re-iterable sequence; a one-shot iterator would be exhausted
    after the first call.
    """
    def fn(x, **kwarg):
        for func in funcs:
            x = func(x, **kwarg)
        return x

    return fn


def join_warper(fsize: int):
    """Return a one-argument callable computing ``join_frame(x, fsize)``."""
    def fn(x):
        return join_frame(x, fsize)

    return fn


def split_warper(fsize: int):
    """Return a one-argument callable computing ``split_frame(x, fsize)``."""
    def fn(x):
        return split_frame(x, fsize)

    return fn