"""
Lingteng Qiu
Baseline image-to-normal (I2Normal) demo, used to show the difference.
"""
|
|
| import sys |
|
|
| sys.path.append("./") |
|
|
| import cv2 |
| import einops |
| import numpy as np |
| import torch |
| from tqdm import tqdm |
| import matplotlib.pyplot as plt |
| import os |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| import json |
| import pdb |
| import argparse |
| import tqlt |
| import random |
| from tqlt import utils as tu |
| from tqlt import op as tqlo |
| from human_generate_system.engineer.NormalEstimator.data_utils import ( |
| HWC3, |
| resize_image, |
| norm_normalize, |
| center_crop, |
| flip_x, |
| ) |
| from PIL import Image |
| from os.path import join |
|
|
# Directory of the local DSINE torch.hub repository, resolved next to this file.
ABS_PATH = join(os.path.dirname(os.path.abspath(__file__)), "DSINE")


def build_arg_parser():
    """Return the command-line parser for the normal-estimation demo.

    Several options (``--num_samples``, ``--strength``, ``--ng_scale``,
    ``--ddim_steps``, ``--eta``, ``--temperature``, ``--save_memory``,
    ``--negative_prompt``, ``--flip``) are accepted for command-line
    compatibility but are not read anywhere in this script.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--num_samples", default=1, type=int)
    parser.add_argument("--image_resolution", default=768, type=int)
    parser.add_argument("--strength", default=1.0, type=float)
    parser.add_argument("--ng_scale", default=1.0, type=float)
    parser.add_argument("--ddim_steps", default=10, type=int)
    parser.add_argument("--seed", default=23012, type=int)
    parser.add_argument("--eta", default=0.0, type=float)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--save_memory", action="store_true")
    parser.add_argument(
        "--negative_prompt",
        default="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
        type=str,
    )
    parser.add_argument("--input", "-i", default=None, type=str)
    parser.add_argument(
        "--prior", default="DSINE", type=str, choices=["DSINE", "geowizard"]
    )
    parser.add_argument(
        "--flip", action="store_true", help="flip init normal and out normal"
    )
    parser.add_argument("--wo_center", action="store_true", help="without center crop")
    parser.add_argument("--num_gpu", default=1, type=int)
    parser.add_argument("--rank", default=0, type=int)
    return parser


def collect_input_paths(input_path):
    """Return the candidate paths named by ``input_path``.

    A file yields a one-element list; a directory yields its entries in
    sorted order, each joined onto the directory path.
    """
    if os.path.isfile(input_path):
        return [input_path]
    return [os.path.join(input_path, name) for name in sorted(os.listdir(input_path))]


def select_shard(items, num_gpu, rank):
    """Return the contiguous slice of ``items`` owned by worker ``rank``.

    With ``num_gpu <= 1`` the list is returned unchanged. Otherwise each
    worker receives ``len(items) // num_gpu`` items and the last worker
    additionally takes the remainder.

    Raises:
        ValueError: if ``num_gpu > 1`` and ``rank`` is outside
            ``[0, num_gpu)``; such a rank would silently process the wrong
            slice (or nothing at all).
    """
    if num_gpu <= 1:
        return items
    if not 0 <= rank < num_gpu:
        raise ValueError(f"rank must be in [0, {num_gpu}), got {rank}")
    bucket = len(items) // num_gpu
    start = bucket * rank
    if rank == num_gpu - 1:
        return items[start:]
    return items[start : start + bucket]


def resolve_output_dir(input_path):
    """Return the ``normal`` output directory for ``input_path``.

    For a directory input this is ``<input>/normal``. For a single-file
    input the directory is placed next to the file; joining ``normal`` onto
    the file path itself can never be created as a directory.
    """
    if os.path.isfile(input_path):
        base_dir = os.path.dirname(input_path)
    else:
        base_dir = input_path
    return os.path.join(base_dir, "normal")


def load_normal_predictor(prior):
    """Load the normal predictor for ``prior``.

    Only ``"DSINE"`` is implemented; it is loaded through ``torch.hub`` from
    the local repository at ``ABS_PATH`` with the checkpoint
    ``./pretrained_models/dsine.pt``.

    Raises:
        NotImplementedError: for any other prior.
    """
    if prior != "DSINE":
        raise NotImplementedError(f"prior {prior!r} is not implemented in this script")
    return torch.hub.load(
        ABS_PATH,
        "DSINE",
        local_file_path="./pretrained_models/dsine.pt",
        source="local",
    )


def normal_to_bgr_uint8(normal):
    """Encode an HxWx3 normal map with values in [-1, 1] as a BGR uint8 image.

    Values are mapped to [0, 255] and clipped before the cast so that
    predictions marginally outside [-1, 1] cannot wrap around in uint8.
    """
    scaled = np.clip((normal + 1) / 2 * 255, 0, 255)
    rgb = np.ascontiguousarray(scaled.astype(np.uint8))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def estimate_normal_image(normal_predictor, prior, input_image, image_resolution):
    """Predict a normal map for ``input_image`` and return it as a BGR image.

    The image is passed through ``HWC3`` and ``resize_image`` before
    inference, and the encoded prediction is resized back to the height and
    width of the ``HWC3`` output.
    """
    raw_input_image = HWC3(input_image)
    ori_h, ori_w = raw_input_image.shape[:2]
    img = resize_image(raw_input_image, image_resolution)

    with torch.no_grad():
        if prior == "DSINE":
            # First batch element, moved to host memory and transposed from
            # channel-first to channel-last for OpenCV.
            pred_normal = normal_predictor.infer_cv2(img)[0]
            pred_normal = pred_normal.cpu().numpy().transpose(1, 2, 0)
        elif prior == "geowizard":
            # NOTE(review): currently unreachable, because loading a
            # geowizard predictor raises NotImplementedError.
            pred_normal = normal_predictor(img, image_resolution)
        else:
            raise NotImplementedError(f"prior {prior!r} is not implemented")

    return cv2.resize(normal_to_bgr_uint8(pred_normal), (ori_w, ori_h))


def main():
    """Estimate a normal map for every input image and write it as a PNG.

    Results are written to ``normal_<stem>.png`` inside the directory
    returned by ``resolve_output_dir``.
    """
    parser = build_arg_parser()
    opt = parser.parse_args()

    # Explicit validation instead of ``assert``, which is stripped under -O.
    if opt.input is None or not os.path.exists(opt.input):
        parser.error(f"--input must be an existing file or directory, got {opt.input!r}")

    # tu.is_img presumably filters the candidates down to image files --
    # TODO confirm against tqlt.
    all_data = tu.is_img(collect_input_paths(opt.input))

    tu.seed_everything(opt.seed, verbose=True)

    all_data = select_shard(all_data, opt.num_gpu, opt.rank)

    normal_predictor = load_normal_predictor(opt.prior)

    output_dir = resolve_output_dir(opt.input)
    os.makedirs(output_dir, exist_ok=True)

    for input_image_path in tqdm(all_data):
        input_image = cv2.imread(input_image_path)
        if input_image is None:
            # cv2.imread signals an unreadable file by returning None; skip
            # it with a warning rather than failing the rest of the batch.
            tqdm.write(
                f"skipping unreadable image: {input_image_path}", file=sys.stderr
            )
            continue
        if not opt.wo_center:
            input_image = center_crop(input_image)

        pred_normal = estimate_normal_image(
            normal_predictor, opt.prior, input_image, opt.image_resolution
        )

        basename = os.path.splitext(os.path.basename(input_image_path))[0]
        cv2.imwrite(f"{output_dir}/normal_{basename}.png", pred_normal)


if __name__ == "__main__":
    main()
|
|