| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | from torch import nn, Tensor |
| |
|
| |
|
| | def _bf16_u16(x: Tensor) -> Tensor: |
| | |
| | return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF |
| |
|
| |
|
class CachedDenoiseStepEmb(nn.Module):
    """Lookup-table replacement for a sigma-embedding module.

    Every scheduler sigma is embedded once through ``base`` at construction
    time; afterwards a bf16 sigma is resolved to its cached embedding via a
    64k-entry LUT keyed on the sigma's raw bf16 bit pattern.  A sigma that
    was not cached deliberately triggers an out-of-bounds index error rather
    than returning a silently wrong embedding.
    """

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device

        # Cast the schedule to bf16 and reject schedules where two sigmas
        # round to the same bit pattern -- the cache key would be ambiguous.
        levels = torch.tensor(sigmas, dtype=torch.bfloat16, device=device)
        patterns = _bf16_u16(levels)
        n = patterns.numel()
        if torch.unique(patterns).numel() != n:
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )

        # Embed every sigma once; row i of `table` is the embedding of sigmas[i].
        with torch.no_grad():
            out = base(levels[:, None])
            table = out.squeeze(1).to(torch.bfloat16).contiguous()

        # 64k-entry inverse map: bf16 bit pattern -> row index (-1 = unknown).
        lut = torch.full((65536,), -1, dtype=torch.int32, device=device)
        lut[patterns] = torch.arange(n, dtype=torch.int32, device=device)

        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # `oob` is one past the last valid row: substituting it for unknown
        # sigmas makes the subsequent gather fail loudly instead of silently.
        self.register_buffer(
            "oob",
            torch.tensor(n, device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        """Return cached embedding rows for bf16 ``sigma`` values."""
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        slot = self.lut[_bf16_u16(sigma)]
        # Unknown patterns map to -1; route them to the out-of-bounds row.
        slot = torch.where(slot < 0, self.oob, slot)
        return self.table[slot.long()]
| |
|
| |
|
class CachedCondHead(nn.Module):
    """Precomputed conditioning head keyed on bf16 embedding bit patterns.

    Each cached sigma-embedding row from ``CachedDenoiseStepEmb.table`` is
    run through ``base`` once at construction time.  At inference the row is
    recognised by the bf16 bit pattern of a single embedding dimension and
    the cached output tuple is returned.  An unrecognised cond deliberately
    produces an out-of-bounds index error (no silent wrong result).
    """

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        S, D = table.shape

        # Evaluate `base` once per cached row; stack the returned sequence of
        # tensors into a single (num_outputs, S, ...) bf16 cache.
        with torch.no_grad():
            outs = base(table[:, None, :])
            cache = torch.stack([o.squeeze(1) for o in outs], 0)
            cache = cache.to(torch.bfloat16).contiguous()

        # Find the first embedding dimension whose bf16 bit patterns are
        # pairwise distinct across all S rows -- it becomes the lookup key.
        key_dim = None
        for cand in range(min(D, max_key_dims)):
            cand_bits = _bf16_u16(table[:, cand])
            if torch.unique(cand_bits).numel() == S:
                key_dim = cand
                key_bits = cand_bits
                break
        if key_dim is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims"
            )

        # 64k-entry inverse map: bf16 bit pattern -> row index (-1 = unknown).
        lut = torch.full((65536,), -1, dtype=torch.int32, device=table.device)
        lut[key_bits] = torch.arange(S, dtype=torch.int32, device=table.device)

        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # One past the last valid row; unknown keys are routed here so the
        # gather fails loudly instead of returning arbitrary data.
        self.register_buffer(
            "oob",
            torch.tensor(S, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        """Return the cached output tuple for a recognised bf16 ``cond``."""
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        slot = self.lut[_bf16_u16(cond[..., self.key_dim])]
        slot = torch.where(slot < 0, self.oob, slot)
        gathered = self.cache[:, slot.long()]
        return tuple(gathered.unbind(0))
| |
|