import math
import random
from typing import List

import cv2
import einops
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything
from transparent_background import Remover

from dataset.opencv_transforms.functional import center_crop, to_tensor
from vtdm.model import create_model
from vtdm.util import tensor2vid

remover = Remover(jit=False)


def pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV BGR array."""
    cv_image = np.array(pil_image)
    cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
    return cv_image


def prepare_white_image(input_image: Image.Image) -> Image.Image:
    """Remove the background and center the subject on a square canvas."""
    # Remove the background, keeping an alpha channel.
    output = remover.process(input_image, type='rgba')

    # Paste the result centered on a square transparent canvas.
    width, height = output.size
    max_side = max(width, height)
    white_image = Image.new('RGBA', (max_side, max_side), (0, 0, 0, 0))
    x_offset = (max_side - width) // 2
    y_offset = (max_side - height) // 2
    white_image.paste(output, (x_offset, y_offset))

    return white_image


class MultiViewGenerator:
    def __init__(self, checkpoint_path, config_path="inference.yaml"):
        self.models = {}
        denoising_model = create_model(config_path).cpu()
        denoising_model.init_from_ckpt(checkpoint_path)
        denoising_model = denoising_model.cuda().half()
        self.models["denoising_model"] = denoising_model

    def denoising(self, frames, args):
        """Run one denoising pass and return the decoded video frames."""
        with torch.no_grad():
            C, T, H, W = frames.shape
            batch = {"video": frames.unsqueeze(0)}
            batch["elevation"] = (
                torch.Tensor([args["elevation"]]).to(torch.int64).to(frames.device)
            )
            batch["fps_id"] = torch.Tensor([7]).to(torch.int64).to(frames.device)
            batch["motion_bucket_id"] = (
                torch.Tensor([127]).to(torch.int64).to(frames.device)
            )
            batch = self.models["denoising_model"].add_custom_cond(batch, infer=True)

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                c, uc = self.models[
                    "denoising_model"
                ].conditioner.get_unconditional_conditioning(
                    batch,
                    force_uc_zero_embeddings=[
                        "cond_frames",
                        "cond_frames_without_noise",
                    ],
                )

            additional_model_inputs = {
                # Batch size 2: one row for the conditional pass, one for the
                # unconditional pass used by classifier-free guidance.
                "image_only_indicator": torch.zeros(2, T).to(
                    self.models["denoising_model"].device
                ),
                "num_video_frames": batch["num_video_frames"],
            }

            def denoiser(input, sigma, c):
                return self.models["denoising_model"].denoiser(
                    self.models["denoising_model"].model,
                    input,
                    sigma,
                    c,
                    **additional_model_inputs,
                )

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                # Sample in latent space (spatial dims are downscaled 8x).
                randn = torch.randn(
                    [T, 4, H // 8, W // 8],
                    device=self.models["denoising_model"].device,
                )
                samples = self.models["denoising_model"].sampler(
                    denoiser, randn, cond=c, uc=uc
                )

            samples = self.models["denoising_model"].decode_first_stage(samples.half())
            samples = einops.rearrange(samples, "(b t) c h w -> b c t h w", t=T)
            return tensor2vid(samples)

    def video_pipeline(self, frames, args) -> List[Image.Image]:
        """Generate frames iteratively, then strip backgrounds from the results."""
        num_iter = args["num_iter"]
        out_list = []
        for _ in range(num_iter):
            with torch.no_grad():
                results = self.denoising(frames, args)
            if len(out_list) == 0:
                out_list = out_list + results
            else:
                # The first frame of a follow-up clip duplicates the previous
                # clip's last frame, so drop it.
                out_list = out_list + results[1:]
            # Seed the next iteration with the last generated frame.
            img = out_list[-1]
            img = to_tensor(img)
            img = (img - 0.5) * 2.0
            frames[:, 0] = img

        result = []
        for frame in out_list:
            input_image = Image.fromarray(frame)
            output_image = remover.process(input_image, type='rgba')
            result.append(output_image)
        return result

    def process(self, white_image: Image.Image, args) -> List[Image.Image]:
        """Preprocess the input image into a frame tensor and run the pipeline."""
        img = pil_to_cv2(white_image)
        frame_list = [img] * args["clip_size"]

        # Scale so both sides cover the target resolution, then center-crop.
        h, w = frame_list[0].shape[0:2]
        rate = max(
            args["input_resolution"][0] * 1.0 / h,
            args["input_resolution"][1] * 1.0 / w,
        )
        frame_list = [
            cv2.resize(f, (math.ceil(w * rate), math.ceil(h * rate)))
            for f in frame_list
        ]
        frame_list = [
            center_crop(f, [args["input_resolution"][0], args["input_resolution"][1]])
            for f in frame_list
        ]

        # BGR -> RGB, to CHW float tensors, then normalize to [-1, 1].
        frame_list = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frame_list]
        frame_list = [to_tensor(f) for f in frame_list]
        frame_list = [(f - 0.5) * 2.0 for f in frame_list]
        frames = torch.stack(frame_list, 1)  # (C, T, H, W)
        frames = frames.cuda()

        self.models["denoising_model"].num_samples = args["clip_size"]
        self.models["denoising_model"].image_size = args["input_resolution"]

        return self.video_pipeline(frames, args)

    def infer(self, white_image: Image.Image) -> List[Image.Image]:
        seed = random.randint(0, 65535)
        seed_everything(seed)
        params = {
            "clip_size": 25,
            "input_resolution": [512, 512],
            "num_iter": 1,
            "aes": 6.0,
            "mv": [0.0, 0.0, 0.0, 10.0],
            "elevation": 0,
        }
        return self.process(white_image, params)
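

# Example usage: a minimal sketch of the end-to-end flow. The checkpoint and
# image paths below are placeholders for illustration, not files shipped with
# this module.
if __name__ == "__main__":
    generator = MultiViewGenerator("ckpts/multiview.ckpt")  # placeholder path
    source = Image.open("input.png").convert("RGB")  # placeholder path
    white_image = prepare_white_image(source)
    views = generator.infer(white_image)
    for idx, view in enumerate(views):
        view.save(f"view_{idx:02d}.png")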