|
|
import argparse |
|
|
import os |
|
|
import yaml |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import torchdiffeq |
|
|
import utils |
|
|
from diff2flow import VPDiffusionFlow, dict2namespace |
|
|
import datasets |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def ode_inverse_solve(
    flow_model,
    x_data,
    x_cond,
    steps=100,
    method="dopri5",
    patch_size=64,
    atol=1e-5,
    rtol=1e-5,
):
    """Integrate the flow ODE forward from t=0 (data) to t=1 (noise).

    Args:
        flow_model: Model exposing ``get_velocity(x, t, x_cond, patch_size=...)``.
        x_data: Tensor at t=0 — the initial state of the integration.
        x_cond: Conditioning tensor forwarded untouched to the velocity model.
        steps: Number of intervals in the uniform evaluation time grid.
        method: torchdiffeq solver name (e.g. "dopri5").
        patch_size: Patch size forwarded to the velocity model.
        atol: Absolute tolerance for the adaptive solver.
        rtol: Relative tolerance for the adaptive solver.

    Returns:
        The noise latent x_1, i.e. the ODE solution at the final time t=1.
    """
    # Velocity field queried by the solver. Conditioning and patch size are
    # closed over so the callable matches torchdiffeq's (t, x) signature.
    velocity = lambda t, x: flow_model.get_velocity(
        x, t, x_cond, patch_size=patch_size
    )

    # Uniform grid over [0, 1] with `steps` intervals, on the data's device.
    time_grid = torch.linspace(0.0, 1.0, steps + 1, device=x_data.device)

    trajectory = torchdiffeq.odeint(
        velocity, x_data, time_grid, method=method, atol=atol, rtol=rtol
    )
    # `odeint` stacks states along the time axis; the last entry is x at t=1.
    return trajectory[-1]
|
|
|
|
|
|
|
|
def main():
    """Generate (noise, data, condition) reflow pairs by inverting the flow ODE.

    Loads a trained VPDiffusionFlow checkpoint, iterates the patch-parsed
    training loader, maps each ground-truth image to its noise latent with
    ``ode_inverse_solve``, and saves one ``batch_{i}.pth`` file per batch
    into ``args.output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, required=True)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default="reflow_data")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument("--patch_size", type=int, default=64)
    parser.add_argument("--method", type=str, default="dopri5")
    parser.add_argument("--atol", type=float, default=1e-5)
    parser.add_argument("--rtol", type=float, default=1e-5)
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="Max images to generate (for testing)",
    )
    args = parser.parse_args()

    # Load the YAML config and expose it as attribute-style namespaces.
    with open(os.path.join("configs", args.config), "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)

    # CLI overrides for the config's data section.
    if args.data_dir:
        config.data.data_dir = args.data_dir
    if args.dataset:
        config.data.dataset = args.dataset

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    config.device = device

    # Seed both torch and numpy so generated pairs are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print("Initializing VPDiffusionFlow...")
    flow = VPDiffusionFlow(args, config)
    flow.load_ckpt(args.resume)

    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Loading dataset {config.data.dataset}...")
    DATASET = datasets.__dict__[config.data.dataset](config)

    # BUGFIX: the original first built a loader with parse_patches=False and
    # immediately discarded it by rebuilding with parse_patches=True — only
    # the patch-parsed loader is ever used, so build it once.
    train_loader, _ = DATASET.get_loaders(parse_patches=True)

    print("Starting generation of reflow pairs...")

    count = 0
    for i, (x_batch, img_id) in enumerate(
        tqdm(train_loader, desc="Generating Reflow Pairs")
    ):
        # Patch-parsed batches arrive as (B, P, C, H, W); merge the batch and
        # patch dimensions into a single batch axis.
        if x_batch.ndim == 5:
            x_batch = x_batch.flatten(start_dim=0, end_dim=1)

        # Channels 0-2 are the conditioning input, channels 3+ the ground truth.
        input_img = x_batch[:, :3, :, :].to(device)
        gt_img = x_batch[:, 3:, :, :].to(device)

        # Map images into the model's normalized data range.
        x_cond = utils.sampling.data_transform(input_img)
        x_data = utils.sampling.data_transform(gt_img)

        # Invert the flow: integrate data -> noise without tracking gradients.
        with torch.no_grad():
            x_noise = ode_inverse_solve(
                flow,
                x_data,
                x_cond,
                steps=args.steps,
                method=args.method,
                patch_size=args.patch_size,
                atol=args.atol,
                rtol=args.rtol,
            )

        # Persist the (noise, data, condition) triple on CPU, one file per batch.
        batch_data = {
            "x_noise": x_noise.cpu(),
            "x_data": x_data.cpu(),
            "x_cond": x_cond.cpu(),
        }
        save_path = os.path.join(args.output_dir, f"batch_{i}.pth")
        torch.save(batch_data, save_path)

        print(f"Saved batch {i} to {save_path}")

        count += input_img.shape[0]
        if args.max_images and count >= args.max_images:
            print(f"Reached max images {args.max_images}")
            break
|
|
|
|
|
|
|
|
# Run the reflow-pair generation only when executed as a script (not on import).
if __name__ == "__main__":


    main()
|
|
|