|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Impl of multistep methods to solve the ODE in the diffusion model. |
|
|
""" |
|
|
|
|
|
from typing import Callable, List, Tuple |
|
|
|
|
|
import torch |
|
|
|
|
|
from cosmos_predict1.diffusion.functional.runge_kutta import reg_x0_euler_step, res_x0_rk2_step |
|
|
|
|
|
|
|
|
def order2_fn( |
|
|
x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, List[torch.Tensor]]: |
|
|
""" |
|
|
impl the second order multistep method in https://arxiv.org/pdf/2308.02157 |
|
|
Adams Bashforth approach! |
|
|
""" |
|
|
if x0_preds: |
|
|
x0_s1, s1 = x0_preds[0] |
|
|
x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1) |
|
|
else: |
|
|
x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0] |
|
|
return x_t, [(x0_s, s)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MULTISTEP_FNs = { |
|
|
"2ab": order2_fn, |
|
|
} |
|
|
|
|
|
|
|
|
def get_multi_step_fn(name: str) -> Callable: |
|
|
if name in MULTISTEP_FNs: |
|
|
return MULTISTEP_FNs[name] |
|
|
methods = "\n\t".join(MULTISTEP_FNs.keys()) |
|
|
raise RuntimeError("Only support multistep method\n" + methods) |
|
|
|
|
|
|
|
|
def is_multi_step_fn_supported(name: str) -> bool: |
|
|
""" |
|
|
Check if the multistep method is supported. |
|
|
""" |
|
|
return name in MULTISTEP_FNs |
|
|
|