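"""Inference / validation script for the Wonder3D multi-view diffusion model.

Loads a pretrained MVDiffusionImagePipeline, runs it over a single-image
validation dataset, and writes per-view color (and, for 'joint' prediction,
normal) images plus rembg-masked versions to cfg.save_dir.
"""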
import argparse
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.utils.data
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from tqdm.auto import tqdm
from einops import rearrange
from rembg import remove

from accelerate.utils import set_seed
from diffusers.utils.import_utils import is_xformers_available

from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline

# The whole pipeline runs in half precision.
weight_dtype = torch.float16


@dataclass
class TestConfig:
    pretrained_model_name_or_path: str
    pretrained_unet_path: str
    revision: Optional[str]
    validation_dataset: Dict
    save_dir: str
    seed: Optional[int]
    validation_batch_size: int
    dataloader_num_workers: int

    local_rank: int

    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    unet_from_pretrained_kwargs: Dict
    validation_guidance_scales: List[float]
    validation_grid_nrow: int
    camera_embedding_lr_mult: float

    num_views: int
    camera_embedding_type: str

    pred_type: str  # 'joint' runs normal + color prediction; any other value runs color-only validation

    enable_xformers_memory_efficient_attention: bool

    cond_on_normals: bool
    cond_on_colors: bool

def log_validation(dataloader, pipeline, cfg: TestConfig, weight_dtype, name, save_dir):
    pipeline.set_progress_bar_config(disable=True)

    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.device).manual_seed(cfg.seed)

    images_cond, images_pred = [], defaultdict(list)
    for batch in tqdm(dataloader):
        imgs_in = batch['imgs_in']
        alphas = batch['alphas']
        camera_embeddings = batch['camera_embeddings']
        filename = batch['filename']

        bsz, num_views = imgs_in.shape[0], imgs_in.shape[1]

        # Flatten the view dimension into the batch dimension.
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
        alphas = rearrange(alphas, "B Nv C H W -> (B Nv) C H W")
        camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce")

        images_cond.append(imgs_in)

        with torch.autocast("cuda"):
            for guidance_scale in cfg.validation_guidance_scales:
                out = pipeline(
                    imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale,
                    output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs
                ).images
                images_pred[f"{name}-sample_cfg{guidance_scale:.1f}"].append(out)
                cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}")

                for i in range(bsz):
                    scene = os.path.basename(filename[i])
                    print(scene)
                    scene_dir = os.path.join(cur_dir, scene)
                    outs_dir = os.path.join(scene_dir, "outs")
                    masked_outs_dir = os.path.join(scene_dir, "masked_outs")
                    os.makedirs(outs_dir, exist_ok=True)
                    os.makedirs(masked_outs_dir, exist_ok=True)

                    # Save the conditioning image together with its alpha channel.
                    img_in = imgs_in[i * num_views]
                    alpha = alphas[i * num_views]
                    img_in = torch.cat([img_in, alpha], dim=0)
                    save_image(img_in, os.path.join(scene_dir, scene + ".png"))

                    for j in range(num_views):
                        view = VIEWS[j]
                        idx = i * num_views + j
                        pred = out[idx]

                        out_filename = f"{cfg.pred_type}_000_{view}.png"
                        # save_image returns the uint8 numpy array it wrote to disk.
                        pred = save_image(pred, os.path.join(outs_dir, out_filename))

                        # Mask out the background with rembg before saving the final output.
                        rm_pred = remove(pred)
                        save_image_numpy(rm_pred, os.path.join(scene_dir, out_filename))

        torch.cuda.empty_cache()

def save_image(tensor, fp):
    """Save a CHW float tensor in [0, 1] as an image and return the HWC uint8 array."""
    ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp)
    return ndarr


def save_image_numpy(ndarr, fp):
    """Save an HWC uint8 numpy array as an image file."""
    im = Image.fromarray(ndarr)
    im.save(fp)

def log_validation_joint(dataloader, pipeline, cfg: TestConfig, weight_dtype, name, save_dir):
    pipeline.set_progress_bar_config(disable=True)

    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.device).manual_seed(cfg.seed)

    images_cond = []
    for batch in tqdm(dataloader):
        # Duplicate the batch: the first half predicts normals, the second half colors.
        imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0)
        filename = batch['filename']

        camera_embeddings = torch.cat([batch['camera_embeddings']] * 2, dim=0)
        task_embeddings = torch.cat([batch['normal_task_embeddings'], batch['color_task_embeddings']], dim=0)
        camera_embeddings = torch.cat([camera_embeddings, task_embeddings], dim=-1)

        # Flatten the view dimension into the batch dimension.
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
        camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce")

        images_cond.append(imgs_in)
        num_views = len(VIEWS)
        with torch.autocast("cuda"):
            for guidance_scale in cfg.validation_guidance_scales:
                out = pipeline(
                    imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale,
                    output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs
                ).images

                # The first half of the outputs are normals, the second half colors.
                bsz = out.shape[0] // 2
                normals_pred = out[:bsz]
                images_pred = out[bsz:]

                cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}")

                for i in range(bsz // num_views):
                    scene = filename[i]
                    scene_dir = os.path.join(cur_dir, scene)
                    normal_dir = os.path.join(scene_dir, "normals")
                    masked_colors_dir = os.path.join(scene_dir, "masked_colors")
                    os.makedirs(normal_dir, exist_ok=True)
                    os.makedirs(masked_colors_dir, exist_ok=True)
                    for j in range(num_views):
                        view = VIEWS[j]
                        idx = i * num_views + j
                        normal = normals_pred[idx]
                        color = images_pred[idx]

                        normal_filename = f"normals_000_{view}.png"
                        rgb_filename = f"rgb_000_{view}.png"
                        normal = save_image(normal, os.path.join(normal_dir, normal_filename))
                        color = save_image(color, os.path.join(scene_dir, rgb_filename))

                        # Mask out the background with rembg before saving the final outputs.
                        rm_normal = remove(normal)
                        rm_color = remove(color)

                        save_image_numpy(rm_normal, os.path.join(scene_dir, normal_filename))
                        save_image_numpy(rm_color, os.path.join(masked_colors_dir, rgb_filename))

        torch.cuda.empty_cache()

def load_wonder3d_pipeline(cfg):
    pipeline = MVDiffusionImagePipeline.from_pretrained(
        cfg.pretrained_model_name_or_path,
        torch_dtype=weight_dtype
    )

    # xformers attention is enabled later in main(), gated on
    # cfg.enable_xformers_memory_efficient_attention and availability.

    if torch.cuda.is_available():
        pipeline.to('cuda:0')

    return pipeline

def main(cfg: TestConfig):
    if cfg.seed is not None:
        set_seed(cfg.seed)

    pipeline = load_wonder3d_pipeline(cfg)

    if cfg.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training on some GPUs. If you observe problems, "
                    "please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            pipeline.unet.enable_xformers_memory_efficient_attention()
            print("Using xformers.")
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly.")

    validation_dataset = MVDiffusionDataset(
        **cfg.validation_dataset
    )

    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=cfg.validation_batch_size, shuffle=False, num_workers=cfg.dataloader_num_workers
    )

    os.makedirs(cfg.save_dir, exist_ok=True)

    if cfg.pred_type == 'joint':
        log_validation_joint(
            validation_dataloader,
            pipeline,
            cfg,
            weight_dtype,
            'validation',
            cfg.save_dir
        )
    else:
        log_validation(
            validation_dataloader,
            pipeline,
            cfg,
            weight_dtype,
            'validation',
            cfg.save_dir
        )

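# Entry point: parse --config, merge the loaded config with the structured TestConfig
# schema, pick the view list for the requested number of views, and run validation.
# Any extra command-line arguments are forwarded to utils.misc.load_config via cli_args.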
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    args, extras = parser.parse_known_args()

    from utils.misc import load_config

    cfg = load_config(args.config, cli_args=extras)
    print(cfg)
    schema = OmegaConf.structured(TestConfig)
    cfg = OmegaConf.merge(schema, cfg)

    if cfg.num_views == 6:
        VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
    elif cfg.num_views == 4:
        VIEWS = ['front', 'right', 'back', 'left']
    else:
        raise ValueError(f"Unsupported num_views: {cfg.num_views} (expected 4 or 6)")

    main(cfg)