# LHMPP/engine/NormalEstimator/i2normal.py
# Lingteng Qiu (邱陵腾)
# rm assets & wheels
# 434b0b0
"""
Lingteng Qiu
Baseline I2Normal (image-to-normal) demo, used to show the difference.
"""
import sys
sys.path.append("./")
import cv2
import einops
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
import json
import pdb
import argparse
import tqlt
import random
from tqlt import utils as tu
from tqlt import op as tqlo
from human_generate_system.engineer.NormalEstimator.data_utils import (
HWC3,
resize_image,
norm_normalize,
center_crop,
flip_x,
)
from PIL import Image
from os.path import join
ABS_PATH = join(os.path.dirname(os.path.abspath(__file__)), "DSINE")
def _build_arg_parser():
    """Create the command-line parser for the normal-estimation script.

    The option set is kept as it was so existing command lines keep working.
    NOTE(review): ``--num_samples``, ``--strength``, ``--ng_scale``,
    ``--ddim_steps``, ``--eta``, ``--temperature``, ``--save_memory``,
    ``--negative_prompt`` and ``--flip`` are parsed but not read anywhere in
    the visible script.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--num_samples", default=1, type=int)
    parser.add_argument("--image_resolution", default=768, type=int)
    parser.add_argument("--strength", default=1.0, type=float)
    parser.add_argument("--ng_scale", default=1.0, type=float)
    parser.add_argument("--ddim_steps", default=10, type=int)
    parser.add_argument("--seed", default=23012, type=int)
    parser.add_argument("--eta", default=0.0, type=float)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--save_memory", action="store_true")
    parser.add_argument(
        "--negative_prompt",
        default="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
        type=str,
    )
    parser.add_argument("--input", "-i", default=None, type=str)
    parser.add_argument(
        "--prior", default="DSINE", type=str, choices=["DSINE", "geowizard"]
    )
    parser.add_argument(
        "--flip", action="store_true", help="flip init normal and out normal"
    )
    parser.add_argument("--wo_center", action="store_true", help="without center crop")
    parser.add_argument("--num_gpu", default=1, type=int)
    parser.add_argument("--rank", default=0, type=int)
    return parser


def _select_shard(items, num_gpu, rank):
    """Return the contiguous slice of *items* assigned to worker *rank*.

    With ``num_gpu <= 1`` the input is returned untouched. Otherwise the
    items are split into buckets of ``len(items) // num_gpu`` entries and the
    last rank additionally takes whatever remainder is left over.
    """
    if num_gpu <= 1:
        return items
    bucket = len(items) // num_gpu
    start = bucket * rank
    if rank == num_gpu - 1:
        return items[start:]
    return items[start : start + bucket]


def _normal_to_bgr_uint8(pred_normal):
    """Turn a (3, H, W) normal tensor into an (H, W, 3) uint8 BGR image.

    The components are mapped with ``(x + 1) / 2 * 255``, i.e. the input is
    expected to lie in [-1, 1].
    """
    scaled = (pred_normal + 1) / 2 * 255
    rgb = scaled.cpu().numpy().transpose(1, 2, 0)
    return cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_RGB2BGR)


def main():
    """Estimate a normal map for every input image and write it as PNG.

    ``--input`` may be a single image or a directory. Results are written to a
    ``normal`` sub-directory (next to the image for single-file input, inside
    the directory otherwise) as ``normal_<stem>.png`` at the resolution of the
    (optionally center-cropped) input.

    Raises:
        NotImplementedError: if ``--prior`` is anything other than ``DSINE``.
    """
    parser = _build_arg_parser()
    opt = parser.parse_args()

    # Explicit validation instead of ``assert`` so it survives ``python -O``
    # and produces a normal CLI usage error.
    if opt.input is None or not os.path.exists(opt.input):
        parser.error(
            f"--input must point to an existing file or directory, got: {opt.input!r}"
        )
    if opt.num_gpu > 1 and not 0 <= opt.rank < opt.num_gpu:
        parser.error(
            f"--rank must be in [0, {opt.num_gpu - 1}] when --num_gpu={opt.num_gpu}, "
            f"got: {opt.rank}"
        )

    if os.path.isfile(opt.input):
        input_list = [opt.input]
        # A directory cannot be created beneath a regular file, so results for
        # a single image go next to that image.
        output_root = os.path.dirname(os.path.abspath(opt.input))
    else:
        input_list = [
            os.path.join(opt.input, name) for name in sorted(os.listdir(opt.input))
        ]
        output_root = opt.input
    # presumably keeps only the paths that are image files — confirm in tqlt
    all_data = tu.is_img(input_list)

    tu.seed_everything(opt.seed, verbose=True)
    all_data = _select_shard(all_data, opt.num_gpu, opt.rank)

    if opt.prior != "DSINE":
        raise NotImplementedError(
            f"prior {opt.prior!r} is not implemented; only 'DSINE' is supported"
        )
    normal_predictor = torch.hub.load(
        ABS_PATH,
        "DSINE",
        local_file_path="./pretrained_models/dsine.pt",
        source="local",
    )

    output_dir = os.path.join(output_root, "normal")
    os.makedirs(output_dir, exist_ok=True)

    for input_image_path in tqdm(all_data):
        input_image = cv2.imread(input_image_path)
        if input_image is None:
            # cv2.imread signals an unreadable file by returning None; skip it
            # rather than aborting the whole batch.
            tqdm.write(f"[skip] could not read image: {input_image_path}")
            continue
        if not opt.wo_center:
            input_image = center_crop(input_image)

        with torch.no_grad():
            raw_input_image = HWC3(input_image)
            ori_H, ori_W, _ = raw_input_image.shape
            img = resize_image(raw_input_image, opt.image_resolution)
            pred_normal = normal_predictor.infer_cv2(img)[0]  # (3, H, W)

        pred_normal = _normal_to_bgr_uint8(pred_normal)
        # Bring the prediction back to the size of the (cropped) input image.
        pred_normal = cv2.resize(pred_normal, (ori_W, ori_H))

        stem = os.path.splitext(os.path.basename(input_image_path))[0]
        out_path = f"{output_dir}/normal_{stem}.png"
        if not cv2.imwrite(out_path, pred_normal):
            tqdm.write(f"[warn] failed to write normal map: {out_path}")


if __name__ == "__main__":
    main()