import argparse
import glob
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional

# OpenEXR support must be enabled before cv2 is first imported.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from accelerate.utils import set_seed
from einops import rearrange
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm

from multiview.pipeline_multiclass import StableUnCLIPImg2ImgPipeline

weight_dtype = torch.float16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def tensor_to_numpy(tensor):
    """Convert a CHW float tensor in [0, 1] to an HWC uint8 numpy array."""
    return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()


def nonzero_normalize_depth(depth, mask=None):
    """Normalize depth into [0, 1], taking the minimum over valid (mask > 0) pixels."""
    if mask is not None and mask.max() > 0:
        nonzero_depth_min = depth[mask > 0].min()
    else:
        nonzero_depth_min = 0
    depth = (depth - nonzero_depth_min) / max(depth.max(), 1e-8)
    return np.clip(depth, 0, 1)
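
# Usage sketch (hypothetical files, assuming depth stored as EXR): the
# OPENCV_IO_ENABLE_OPENEXR flag set above is what lets cv2 read .exr maps.
#   depth = cv2.imread('depth.exr', cv2.IMREAD_UNCHANGED)
#   mask = (depth > 0).astype(np.uint8)
#   depth01 = nonzero_normalize_depth(depth, mask)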


class SingleImageData(Dataset):
    def __init__(self,
                 input_dir,
                 prompt_embeds_path='./multiview/fixed_prompt_embeds_6view',
                 image_transforms=[],
                 total_views=6,
                 ext="png",
                 return_paths=True,
                 ) -> None:
        """Create a dataset from a folder of images.

        The root directory is searched (non-recursively) for images with
        the given extension.
        """
        self.input_dir = Path(input_dir)
        self.return_paths = return_paths
        self.total_views = total_views

        self.paths = glob.glob(str(self.input_dir / f'*.{ext}'))

        print(f'============= length of dataset {len(self.paths)} =============')
        self.tform = image_transforms
        self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
        self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')

    def __len__(self):
        return len(self.paths)

    def load_rgb(self, path, color):
        # Read the image as a float array in [0, 1], then letterbox it onto a
        # white 1024x1024 canvas, preserving aspect ratio. `color` is unused.
        img = plt.imread(path)
        img = Image.fromarray(np.uint8(img * 255.))
        new_img = Image.new("RGB", (1024, 1024))

        width, height = img.size
        new_width = int(width / height * 1024)
        img = img.resize((new_width, 1024))
        new_img.paste((255, 255, 255), (0, 0, 1024, 1024))
        offset = (1024 - new_width) // 2
        new_img.paste(img, (offset, 0))
        return new_img

    def __getitem__(self, index):
        data = {}
        filename = self.paths[index]

        if self.return_paths:
            data["path"] = str(filename)
        color = 1.0
        cond_im_rgb = self.process_im(self.load_rgb(filename, color))
        cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0)

        data["image_cond_rgb"] = cond_im_rgb
        data["normal_prompt_embeddings"] = self.normal_text_embeds
        data["color_prompt_embeddings"] = self.color_text_embeds
        data["filename"] = os.path.basename(filename)

        return data

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)

    def tensor_to_image(self, tensor):
        return Image.fromarray(np.uint8(tensor.numpy() * 255.))
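
# Usage sketch (hypothetical path): each sample repeats the letterboxed
# conditioning image total_views times.
#   ds = SingleImageData(input_dir='./result/apose', image_transforms=transforms.ToTensor())
#   sample = ds[0]
#   print(sample['image_cond_rgb'].shape)  # torch.Size([6, 3, 1024, 1024])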


@dataclass
class TestConfig:
    pretrained_model_name_or_path: str
    pretrained_unet_path: Optional[str]
    revision: Optional[str]
    validation_dataset: Dict
    save_dir: str
    seed: Optional[int]
    validation_batch_size: int
    dataloader_num_workers: int
    save_mode: str
    local_rank: int

    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    unet_from_pretrained_kwargs: Dict
    validation_grid_nrow: int
    camera_embedding_lr_mult: float

    num_views: int
    camera_embedding_type: str

    pred_type: str
    regress_elevation: bool
    enable_xformers_memory_efficient_attention: bool

    cond_on_normals: bool
    cond_on_colors: bool

    regress_focal_length: bool


def save_image(tensor, fp):
    """Save a CHW float tensor in [0, 1] as a padded 1024x1024 PNG."""
    ndarr = tensor_to_numpy(tensor)
    save_image_numpy(ndarr, fp)
    return ndarr


def save_image_numpy(ndarr, fp):
    im = Image.fromarray(ndarr)

    # Letterbox non-square images onto a white square canvas before resizing,
    # so outputs are padded rather than stretched.
    if im.size[0] != im.size[1]:
        size = max(im.size)
        new_im = Image.new("RGB", (size, size))

        new_im.paste((255, 255, 255), (0, 0, size, size))
        new_im.paste(im, ((size - im.size[0]) // 2, (size - im.size[1]) // 2))
        im = new_im

    im = im.resize((1024, 1024), Image.LANCZOS)
    im.save(fp)
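
# Usage sketch (hypothetical tensor): a 3x1024x576 prediction in [0, 1] is
# letterboxed onto a white square and written out as a 1024x1024 PNG.
#   save_image(torch.rand(3, 1024, 576), 'out.png')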


def run_multiview_infer(dataloader, pipeline, cfg: TestConfig, save_dir, num_levels=3):
    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)

    images_cond = []
    for batch in tqdm(dataloader):
        torch.cuda.empty_cache()
        images_cond.append(batch['image_cond_rgb'][:, 0].cuda())
        # Duplicate the conditioning images: the first half of the batch is
        # denoised with normal-map prompt embeddings, the second half with
        # color prompt embeddings.
        imgs_in = torch.cat([batch['image_cond_rgb']] * 2, dim=0).cuda()
        num_views = imgs_in.shape[1]
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")

        target_h, target_w = imgs_in.shape[-2], imgs_in.shape[-1]

        normal_prompt_embeddings = batch['normal_prompt_embeddings'].cuda()
        clr_prompt_embeddings = batch['color_prompt_embeddings'].cuda()
        prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
        prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")

        unet_out = pipeline(
            imgs_in, None, prompt_embeds=prompt_embeddings,
            generator=generator, guidance_scale=3.0, output_type='pt', num_images_per_prompt=1,
            height=cfg.height, width=cfg.width,
            num_inference_steps=40, eta=1.0,
            num_levels=num_levels,
        )

        for level in range(num_levels):
            out = unet_out[level].images
            bsz = out.shape[0] // 2

            # First half of the outputs are normals, second half are colors,
            # mirroring the input duplication above.
            normals_pred = out[:bsz]
            images_pred = out[bsz:]

            cur_dir = save_dir
            os.makedirs(cur_dir, exist_ok=True)

            for i in range(bsz // num_views):
                scene = batch['filename'][i].split('.')[0]
                scene_dir = os.path.join(cur_dir, scene, f'level{level}')
                os.makedirs(scene_dir, exist_ok=True)

                img_in_ = images_cond[-1][i].to(out.device)
                for j in range(num_views):
                    view = VIEWS[j]  # view name; the filenames below use the index j
                    idx = i * num_views + j
                    normal = normals_pred[idx]
                    color = images_pred[idx]

                    normal_filename = f"normal_{j}.png"
                    rgb_filename = f"color_{j}.png"
                    save_image(normal, os.path.join(scene_dir, normal_filename))
                    save_image(color, os.path.join(scene_dir, rgb_filename))

        torch.cuda.empty_cache()
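
# run_multiview_infer writes, for each input image:
#   {save_dir}/{scene}/level{level}/normal_{j}.png and color_{j}.png
# for every level in range(num_levels) and view index j in range(num_views).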


def load_multiview_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
        cfg.pretrained_path,
        torch_dtype=torch.float16,
    )
    pipeline.unet.enable_xformers_memory_efficient_attention()
    if torch.cuda.is_available():
        pipeline.to(device)
    return pipeline


def main(cfg: TestConfig):
    set_seed(cfg.seed)
    # load_multiview_pipeline already moves the pipeline to the GPU when available.
    pipeline = load_multiview_pipeline(cfg)

    image_transforms = [transforms.Resize(int(max(cfg.height, cfg.width))),
                        transforms.CenterCrop((cfg.height, cfg.width)),
                        transforms.ToTensor(),
                        transforms.Lambda(lambda x: x * 2. - 1),
                        ]
    image_transforms = transforms.Compose(image_transforms)
    dataset = SingleImageData(image_transforms=image_transforms, input_dir=cfg.input_dir, total_views=cfg.num_views)
    dataloader = DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=1
    )
    os.makedirs(cfg.output_dir, exist_ok=True)

    with torch.no_grad():
        run_multiview_infer(dataloader, pipeline, cfg, cfg.output_dir, num_levels=cfg.num_levels)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_views", type=int, default=6)
    parser.add_argument("--num_levels", type=int, default=3)
    parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=576)
    parser.add_argument("--input_dir", type=str, default='./result/apose')
    parser.add_argument("--output_dir", type=str, default='./result/multiview')
    cfg = parser.parse_args()

    if cfg.num_views == 6:
        VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
    else:
        raise NotImplementedError(f"Number of views {cfg.num_views} not supported")
    main(cfg)
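
# Example invocation (defaults shown; the script filename here is hypothetical):
#   python infer_multiview.py --input_dir ./result/apose --output_dir ./result/multiview \
#       --pretrained_path ./ckpt/StdGEN-multiview-1024 --num_views 6 --num_levels 3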