Spaces:
Configuration error
Configuration error
| import torch | |
| from diffusers import StableVideoDiffusionPipeline | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import rembg | |
| import argparse | |
| import imageio | |
| import os | |
def add_margin(pil_img, top, right, bottom, left, color):
    """Return a copy of ``pil_img`` surrounded by a solid-colour border.

    Args:
        pil_img: Source PIL image; its mode is reused for the new canvas.
        top: Border thickness in pixels above the image.
        right: Border thickness in pixels to the right of the image.
        bottom: Border thickness in pixels below the image.
        left: Border thickness in pixels to the left of the image.
        color: Fill colour handed to ``Image.new`` for the border area.

    Returns:
        A new image of size ``(width + left + right, height + top + bottom)``
        with the original pasted at offset ``(left, top)``.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)
    # Start from a canvas that is entirely border colour, then drop the
    # source image into place.
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas
def resize_image(image, output_size=(1024, 576)):
    """Squash ``image`` to a square and pad it horizontally to ``output_size``.

    The image is resized (aspect ratio is not preserved) to ``H x H`` where
    ``H = output_size[1]``; left and right margins are then added so the
    result is ``output_size[0]`` pixels wide. The margins are filled with the
    colour of the resized image's top-left pixel.

    Args:
        image: PIL image to pad.
        output_size: Target ``(width, height)`` in pixels.

    Returns:
        A new PIL image; it has size ``output_size`` when width >= height.
    """
    target_w = output_size[0]
    target_h = output_size[1]
    image = image.resize((target_h, target_h))

    # Split the padding so the result is exactly target_w wide even when
    # (target_w - target_h) is odd; the extra pixel goes to the right so the
    # left offset stays at (target_w - target_h) // 2.
    total_pad = target_w - target_h
    left_pad = total_pad // 2
    right_pad = total_pad - left_pad

    # getpixel() returns a fill value that is valid for the image's own mode
    # (an int for single-channel images, a tuple of ints otherwise), unlike
    # tuple(np.array(image)[0, 0]) which breaks on single-channel images.
    fill = image.getpixel((0, 0))
    return add_margin(image, 0, right_pad, 0, left_pad, fill)
def load_image(file, W, H, bg='white'):
    """Load an image, remove its background with rembg and flatten it to RGB.

    Args:
        file: Path of the image to read.
        W: Output width in pixels.
        H: Output height in pixels.
        bg: ``'white'`` blends the colour channels over white using the alpha
            channel of the rembg output; ``'black'`` uses the colour channels
            of the rembg output as they are.

    Returns:
        An 8-bit, three-channel ``PIL.Image`` of size ``(W, H)``.

    Raises:
        NotImplementedError: If ``bg`` is neither ``'white'`` nor ``'black'``.
        FileNotFoundError: If OpenCV cannot read ``file``.
    """
    print(f'[INFO] load image from {file}...')

    # Reject unsupported modes up front instead of after the (slow)
    # background-removal pass.
    if bg not in ('white', 'black'):
        raise NotImplementedError(f"unsupported background mode: {bg!r}")

    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"could not read image: {file}")

    bg_remover = rembg.new_session()
    img = rembg.remove(img, session=bg_remover)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32) / 255.0

    # Channel 3 is used as the foreground mask (assumes rembg returned a
    # 4-channel array); kept as (..., 1) so it broadcasts over colour.
    input_mask = img[..., 3:]
    if bg == 'white':
        input_img = img[..., :3] * input_mask + (1 - input_mask)
    else:  # bg == 'black'
        input_img = img[..., :3]

    # OpenCV channel order is BGR; PIL expects RGB.
    input_img = input_img[..., ::-1].copy()
    return Image.fromarray(np.uint8(input_img * 255))
def load_image_w_bg(file, W, H):
    """Load an image with its original background and return it as RGB.

    Args:
        file: Path of the image to read.
        W: Output width in pixels.
        H: Output height in pixels.

    Returns:
        An 8-bit, three-channel ``PIL.Image`` of size ``(W, H)``. Any fourth
        channel present in the file is discarded.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``file``.
    """
    print(f'[INFO] load image from {file}...')
    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"could not read image: {file}")

    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
    # A single-channel file comes back 2-D; replicate it to three channels so
    # the channel slice below does not cut image columns instead.
    if img.ndim == 2:
        img = np.repeat(img[..., None], 3, axis=-1)
    img = img.astype(np.float32) / 255.0
    input_img = img[..., :3]

    # OpenCV channel order is BGR; PIL expects RGB.
    input_img = input_img[..., ::-1].copy()
    return Image.fromarray(np.uint8(input_img * 255))
def _load_conditioning_image(input_path, width, height, bg, is_pad):
    """Load the input image for one background mode, padding it if requested."""
    if bg == 'orig':
        image = load_image_w_bg(input_path, width, height)
    else:
        image = load_image(input_path, width, height, bg)
    if is_pad:
        image = resize_image(image, output_size=(width, height))
    return image


def gen_vid(input_path, seed, bg, is_pad):
    """Generate video(s) from a single image with Stable Video Diffusion.

    Two modes, selected by ``seed``:

    * ``seed is None`` -- sweep: for each background in
      ``('white', 'black', 'orig')`` and each seed in ``range(20)``, write
      ``<input_dir>/videos/<name>_<bg>_<seed>.mp4``. The ``bg`` argument is
      ignored, and ``'orig'`` is skipped when the file name contains
      ``'rgba'``.
    * otherwise -- single run: write ``<input_dir>/<name>_generated.mp4`` and
      one PNG per frame into ``<input_dir>/<name>_frames/``.

    Args:
        input_path: Path of the input image ('/'-separated).
        seed: Integer seed for the torch generator, or ``None`` for the sweep.
        bg: ``'white'``, ``'black'`` or ``'orig'`` (single-run mode only).
        is_pad: If true, run at 1024x576 with the image squashed to a square
            and padded left/right, and crop saved frames back to the square;
            otherwise run at 512x512.

    Raises:
        ValueError: In single-run mode, if ``bg == 'orig'`` and the file name
            contains ``'rgba'``.
    """
    name = input_path.split('/')[-1].split('.')[0]
    # dirname() is '' for a bare file name; without the fallback the output
    # paths below would start with '/' and point at the filesystem root.
    input_dir = os.path.dirname(input_path) or '.'

    # Fail before the (slow) model load when the request cannot be served.
    if seed is not None and bg == 'orig' and 'rgba' in name:
        raise ValueError(
            f"bg='orig' is not supported for RGBA inputs (file name {name!r})"
        )

    pipe = StableVideoDiffusionPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
    )
    pipe.to("cuda")

    if is_pad:
        height, width = 576, 1024
    else:
        height, width = 512, 512

    if seed is None:
        video_dir = f"{input_dir}/videos"
        # The writer does not create missing directories.
        os.makedirs(video_dir, exist_ok=True)
        for bg_mode in ('white', 'black', 'orig'):
            if bg_mode == 'orig' and 'rgba' in name:
                continue
            image = _load_conditioning_image(input_path, width, height, bg_mode, is_pad)
            for sweep_seed in range(20):
                generator = torch.manual_seed(sweep_seed)
                frames = pipe(image, height=height, width=width, generator=generator).frames[0]
                imageio.mimwrite(f"{video_dir}/{name}_{bg_mode}_{sweep_seed:03}.mp4", frames, fps=7)
        return

    image = _load_conditioning_image(input_path, width, height, bg, is_pad)
    generator = torch.manual_seed(seed)
    frames = pipe(image, height=height, width=width, generator=generator).frames[0]
    imageio.mimwrite(f"{input_dir}/{name}_generated.mp4", frames, fps=7)

    frames_dir = f"{input_dir}/{name}_frames"
    os.makedirs(frames_dir, exist_ok=True)
    # Horizontal offset of the square content inside the padded frame.
    crop_left = (width - height) // 2
    for idx, img in enumerate(frames):
        if is_pad:
            # Remove the side padding so saved frames are square again.
            img = img.crop((crop_left, 0, width - crop_left, height))
        img.save(f"{frames_dir}/{idx:03}.png")
def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: The raw argument string (a ``bool`` is passed through).

    Returns:
        The parsed boolean. An empty string is treated as ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--bg", type=str, default='white')
    parser.add_argument("--is_pad", type=str2bool, default=False)
    # parse_known_args keeps the script tolerant of unrecognised extra flags.
    args, _ = parser.parse_known_args()
    gen_vid(args.path, args.seed, args.bg, args.is_pad)