"""Lotus depth / normal inference helpers and command-line entry point.

Provides single-image (`lotus`) and video (`lotus_video`) helpers that run
both the generation (Lotus-G) and regression (Lotus-D) pipelines, plus a
`main()` CLI that runs one pipeline over a directory of .png/.jpg images.
"""
import argparse
import logging
import os
from contextlib import nullcontext
from pathlib import Path

import cv2
import numpy as np
import torch
from diffusers.utils import check_min_version
from PIL import Image
from tqdm.auto import tqdm

from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from utils.seed_all import seed_all

check_min_version('0.28.0.dev0')


def _autocast_context(pipe):
    """Return the autocast context used while running *pipe*.

    When Apple MPS is available a no-op context is returned; otherwise
    autocast is enabled for the pipeline's device type.
    """
    if torch.backends.mps.is_available():
        return nullcontext()
    return torch.autocast(pipe.device.type)


def _preprocess_image(image, device, dtype=None):
    """Convert a channels-last image to a (1, 3, H, W) tensor in [-1, 1].

    Args:
        image: PIL image or H x W x 3 array; assumed 8-bit RGB (0..255).
        device: Device the tensor is moved to.
        dtype: Optional dtype to cast to after moving. When None the tensor
            keeps the float16 dtype it is built with.

    Returns:
        The normalized tensor on *device*.
    """
    array = np.array(image).astype(np.float16)
    tensor = torch.tensor(array).permute(2, 0, 1).unsqueeze(0)
    tensor = (tensor / 127.5 - 1.0).to(device)
    if dtype is not None:
        tensor = tensor.to(dtype)
    return tensor


def _make_task_emb(device):
    """Build the (1, 4) tensor passed to the pipelines as ``task_emb``.

    The code [1, 0] is expanded to [sin(1), sin(0), cos(1), cos(0)].
    """
    code = torch.tensor([[1.0, 0.0]], dtype=torch.float32, device=device)
    return torch.cat([torch.sin(code), torch.cos(code)], dim=-1)


def _postprocess(pred, task_name, **colorize_kwargs):
    """Turn a raw prediction into ``(output_npy, output_color)``.

    For 'depth' the last axis is averaged away and the result is colorized
    with ``colorize_depth_map(output_npy, **colorize_kwargs)``. For any other
    task the prediction is kept as-is and rendered as an 8-bit image by
    scaling by 255 (presumably *pred* is channels-last in [0, 1] -- confirm).
    """
    if task_name == 'depth':
        output_npy = pred.mean(axis=-1)
        output_color = colorize_depth_map(output_npy, **colorize_kwargs)
    else:
        output_npy = pred
        output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
    return output_npy, output_color


def infer_pipe(pipe, test_image, task_name, seed, device, video_depth=False):
    """Run a one-step Lotus pipeline on a single image.

    Args:
        pipe: Lotus-G or Lotus-D pipeline.
        test_image: Path to an image file or, when *video_depth* is True, an
            already-decoded channels-last frame (assumed RGB).
        task_name: 'depth'; any other value is treated as normals.
        seed: Seed for a fresh ``torch.Generator`` on *device*; None leaves
            the run unseeded.
        device: Device the input tensors are placed on.
        video_depth: Treat *test_image* as an in-memory frame, not a path.

    Returns:
        The visualization (colorized depth map or 8-bit normal image).
    """
    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    with _autocast_context(pipe):
        if not video_depth:
            # Close the file handle once the pixels are decoded.
            with Image.open(test_image) as img:
                test_image = img.convert('RGB')
        rgb_in = _preprocess_image(test_image, device)

        pred = pipe(
            rgb_in=rgb_in,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=_make_task_emb(device),
        ).images[0]

        _, output_color = _postprocess(pred, task_name, reverse_color=True)
    return output_color


def infer_pipe_video(pipe, test_image, task_name, generator, device, latents=None):
    """Run a one-step Lotus pipeline on one decoded video frame.

    Args:
        pipe: Pipeline to run (Lotus-G when called from `lotus_video`).
        test_image: Channels-last frame array (assumed RGB).
        task_name: 'depth'; any other value is treated as normals.
        generator: ``torch.Generator`` or None, forwarded to the pipeline.
        device: Device the input tensors are placed on.
        latents: Optional starting latents forwarded to the pipeline.

    Returns:
        ``(visualization, last_frame_latent)``, where ``last_frame_latent``
        is element 2 of the pipeline's tuple output (``return_dict=False``).
    """
    with _autocast_context(pipe):
        rgb_in = _preprocess_image(test_image, device)

        output = pipe(
            rgb_in=rgb_in,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            latents=latents,
            output_type='np',
            timesteps=[999],
            task_emb=_make_task_emb(device),
            return_dict=False,
        )
        pred = output[0][0]
        last_frame_latent = output[2]

        _, output_color = _postprocess(pred, task_name, reverse_color=True)
    return output_color, last_frame_latent


def load_pipe(task_name, device):
    """Load the float16 Lotus-G and Lotus-D pipelines for *task_name*.

    'depth' selects the v2.0 disparity checkpoints; any other task selects
    the v1.0 normal checkpoints. Both pipelines are moved to *device* and
    have their progress bars disabled.

    Returns:
        ``(pipe_g, pipe_d)``
    """
    if task_name == 'depth':
        model_g = 'jingheya/lotus-depth-g-v2-0-disparity'
        model_d = 'jingheya/lotus-depth-d-v2-0-disparity'
    else:
        model_g = 'jingheya/lotus-normal-g-v1-0'
        model_d = 'jingheya/lotus-normal-d-v1-0'

    dtype = torch.float16
    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=dtype)
    pipe_d = LotusDPipeline.from_pretrained(model_d, torch_dtype=dtype)
    for pipe in (pipe_g, pipe_d):
        pipe.to(device)
        pipe.set_progress_bar_config(disable=True)
    logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
    return pipe_g, pipe_d


def _read_video_frames(input_video):
    """Decode every frame of *input_video*.

    Returns:
        ``(frames, fps, width, height)``. Frames are converted from OpenCV's
        BGR order to RGB, matching the PIL ``convert('RGB')`` path used for
        still images. The capture is released even if decoding fails.
    """
    cap = cv2.VideoCapture(input_video)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
    return frames, fps, width, height


def lotus_video(input_video, task_name, seed, device):
    """Run Lotus-G and Lotus-D on every frame of a video.

    Lotus-G starts each frame from one shared noise tensor. From the second
    frame on, that noise is blended 0.9 / 0.1 with the latent returned for
    the previous frame (presumably for temporal consistency -- confirm).

    Returns:
        ``(output_g, output_d, fps)``: per-frame visualizations from each
        pipeline and the source frame rate.
    """
    pipe_g, pipe_d = load_pipe(task_name, device)
    frames, fps, width, height = _read_video_frames(input_video)

    # Shared starting noise for Lotus-G.
    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)
    latent_common = torch.randn(
        (1, 4, height // pipe_g.vae_scale_factor, width // pipe_g.vae_scale_factor),
        generator=generator,
        dtype=pipe_g.dtype,
        device=device,
    )

    output_g = []
    output_d = []
    last_frame_latent = None
    for frame in frames:
        latents = latent_common
        if last_frame_latent is not None:
            latents = 0.9 * latents + 0.1 * last_frame_latent
        # infer_pipe_video's fourth parameter is a torch.Generator, not the int seed.
        output_frame_g, last_frame_latent = infer_pipe_video(
            pipe_g, frame, task_name, generator, device, latents
        )
        output_frame_d = infer_pipe(pipe_d, frame, task_name, seed, device, video_depth=True)
        output_g.append(output_frame_g)
        output_d.append(output_frame_d)

    return output_g, output_d, fps


def lotus(image_input, task_name, seed, device):
    """Run Lotus-G and Lotus-D on one image file.

    Returns:
        ``(output_g, output_d)`` visualizations.
    """
    pipe_g, pipe_d = load_pipe(task_name, device)
    output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
    output_d = infer_pipe(pipe_d, image_input, task_name, seed, device)
    return output_g, output_d


def parse_args():
    '''Set the Args'''
    parser = argparse.ArgumentParser(
        description="Run Lotus..."
    )
    # model settings
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="pretrained model path from hugging face or local dir",
    )
    parser.add_argument(
        "--prediction_type",
        type=str,
        default="sample",
        help="The used prediction_type. ",
    )
    parser.add_argument(
        "--timestep",
        type=int,
        default=999,
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="regression",  # "generation"
        help="Whether to use the generation or regression pipeline."
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="depth",  # "normal"
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        help="Whether or not to use xformers."
    )

    # inference settings
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory."
    )
    parser.add_argument(
        "--input_dir", type=str, required=True, help="Input directory."
    )
    parser.add_argument(
        "--half_precision",
        action="store_true",
        help="Run with half-precision (16-bit float), might lead to suboptimal result.",
    )

    args = parser.parse_args()

    return args


def main():
    """CLI entry point: run one Lotus pipeline over every .png/.jpg under --input_dir.

    Results are written to ``<output_dir>/<task_name>_vis/*.png`` as
    visualizations and to ``<output_dir>/<task_name>/*.npy`` as raw arrays.
    """
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Run inference...")

    args = parse_args()

    # -------------------- Preparation --------------------
    # Random seed
    if args.seed is not None:
        seed_all(args.seed)

    # Output directories
    os.makedirs(args.output_dir, exist_ok=True)
    logging.info(f"Output dir = {args.output_dir}")
    output_dir_color = os.path.join(args.output_dir, f'{args.task_name}_vis')
    output_dir_npy = os.path.join(args.output_dir, f'{args.task_name}')
    os.makedirs(output_dir_color, exist_ok=True)
    os.makedirs(output_dir_npy, exist_ok=True)

    # Precision: full float32 by default, float16 only with --half_precision.
    if args.half_precision:
        dtype = torch.float16
        logging.info(f"Running with half precision ({dtype}).")
    else:
        dtype = torch.float32

    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"Device = {device}")

    # -------------------- Data --------------------
    root_dir = Path(args.input_dir)
    test_images = sorted(list(root_dir.rglob('*.png')) + list(root_dir.rglob('*.jpg')))
    print('==> There are', len(test_images), 'images for validation.')

    # -------------------- Model --------------------
    if args.mode == 'generation':
        pipeline_cls = LotusGPipeline
    elif args.mode == 'regression':
        pipeline_cls = LotusDPipeline
    else:
        raise ValueError(f'Invalid mode: {args.mode}')
    pipeline = pipeline_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        torch_dtype=dtype,
    )
    logging.info(f"Successfully loading pipeline from {args.pretrained_model_name_or_path}.")

    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    generator = None if args.seed is None else torch.Generator(device=device).manual_seed(args.seed)

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for image_path in tqdm(test_images):
            with Image.open(image_path) as img:
                rgb = img.convert('RGB')
            # Match the input dtype to the pipeline weights (no autocast here).
            rgb_in = _preprocess_image(rgb, device, dtype=dtype)

            pred = pipeline(
                rgb_in=rgb_in,
                prompt='',
                num_inference_steps=1,
                generator=generator,
                output_type='np',
                timesteps=[args.timestep],
                task_emb=_make_task_emb(device),
            ).images[0]

            # The CLI keeps colorize_depth_map's default colouring.
            output_npy, output_color = _postprocess(pred, args.task_name)

            save_file_name = image_path.stem
            output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
            np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)

    print('==> Inference is done. \n==> Results saved to:', args.output_dir)


if __name__ == '__main__':
    main()