# File size: 4,207 Bytes
# 21c4e64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch

from diffusers import StableVideoDiffusionPipeline

from PIL import Image
import numpy as np

import cv2
import rembg

import argparse
import imageio
import os

def add_margin(pil_img, top, right, bottom, left, color):
    """Return a copy of ``pil_img`` surrounded by a solid-colour border.

    Args:
        pil_img: source PIL image; its ``mode`` is reused for the new canvas.
        top: border thickness in pixels above the image.
        right: border thickness in pixels to the right of the image.
        bottom: border thickness in pixels below the image.
        left: border thickness in pixels to the left of the image.
        color: fill value handed to ``Image.new`` for the canvas.

    Returns:
        A new image of size ``(width + left + right, height + top + bottom)``
        with the original pasted at offset ``(left, top)``.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    # The paste offset is the top-left corner of the original inside the canvas.
    canvas.paste(pil_img, (left, top))
    return canvas

def resize_image(image, output_size=(1024, 576)):
    """Resize ``image`` to a square and pad it horizontally to ``output_size``.

    The image is first resized to ``height x height`` (``output_size[1]``) and
    then padded on the left and right so the final width is ``output_size[0]``.
    The padding colour is the top-left pixel of the resized image.

    Args:
        image: PIL image to resize and pad.
        output_size: ``(width, height)`` of the result.

    Returns:
        A new PIL image of size ``output_size``.
    """
    target_w, target_h = output_size[0], output_size[1]
    square = image.resize((target_h, target_h))

    # Split the horizontal padding so that an odd difference still yields
    # exactly ``target_w`` columns (the extra pixel goes on the right).
    total_pad = target_w - target_h
    left_pad = total_pad // 2
    right_pad = total_pad - left_pad

    # getpixel returns an int for single-band images and a tuple otherwise,
    # so it works for every image mode (tuple(np.array(...)[0, 0]) fails on
    # single-band images because the indexed value is a scalar).
    fill = square.getpixel((0, 0))
    return add_margin(square, 0, right_pad, 0, left_pad, fill)


def load_image(file, W, H, bg='white'):
    """Load an image, run rembg on it and composite it onto a plain background.

    Args:
        file: path of the image to read with ``cv2.imread``.
        W: target width in pixels.
        H: target height in pixels.
        bg: ``'white'`` to blend the result over white using the fourth
            channel as the mask, or ``'black'`` to use the colour channels of
            the rembg output as they are.

    Returns:
        A PIL image of size ``(W, H)`` built from the first three channels,
        with the channel order reversed (BGR -> RGB).

    Raises:
        NotImplementedError: if ``bg`` is neither ``'white'`` nor ``'black'``.
        FileNotFoundError: if ``cv2.imread`` could not read ``file``.
    """
    # load image
    print(f'[INFO] load image from {file}...')
    if bg not in ('white', 'black'):
        # Reject unsupported modes before the expensive background removal.
        raise NotImplementedError(f"unsupported bg {bg!r}; expected 'white' or 'black'")

    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            f'cannot read image {file!r}: file is missing or not a supported image'
        )

    bg_remover = rembg.new_session()
    img = rembg.remove(img, session=bg_remover)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32) / 255.0
    # The fourth channel of the rembg output is used as the foreground mask
    # (assumes rembg returns a 4-channel array here -- TODO confirm).
    input_mask = img[..., 3:]
    if bg == 'white':
        # white bg: foreground weighted by the mask, white everywhere else
        input_img = img[..., :3] * input_mask + (1 - input_mask)
    else:
        # black bg: colour channels are used unchanged
        input_img = img[..., :3]
    # bgr to rgb
    input_img = input_img[..., ::-1].copy()
    input_img = Image.fromarray(np.uint8(input_img*255))
    return input_img

def load_image_w_bg(file, W, H):
    """Load an image and resize it, keeping its original background.

    Args:
        file: path of the image to read with ``cv2.imread``.
        W: target width in pixels.
        H: target height in pixels.

    Returns:
        A PIL image of size ``(W, H)`` built from the first three channels,
        with the channel order reversed (BGR -> RGB). Any further channel
        (e.g. alpha) is dropped.

    Raises:
        FileNotFoundError: if ``cv2.imread`` could not read ``file``.
    """
    # load image
    print(f'[INFO] load image from {file}...')
    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            f'cannot read image {file!r}: file is missing or not a supported image'
        )
    if img.ndim == 2:
        # A single-channel read has no channel axis; without this the
        # ``[..., :3]`` slice below would cut image columns instead.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32) / 255.0
    input_img = img[..., :3]
    # bgr to rgb
    input_img = input_img[..., ::-1].copy()
    input_img = Image.fromarray(np.uint8(input_img*255))
    return input_img

def _load_conditioning_image(input_path, width, height, bg, is_pad):
    """Load the input image for background mode ``bg`` and pad it if requested."""
    if bg == 'orig':
        image = load_image_w_bg(input_path, width, height)
    else:
        image = load_image(input_path, width, height, bg)
    if is_pad:
        image = resize_image(image, output_size=(width, height))
    return image


def gen_vid(input_path, seed, bg, is_pad):
    """Generate video(s) from one image with Stable Video Diffusion.

    Two modes, selected by ``seed``:

    * ``seed is None``: sweep the backgrounds ``'white'``, ``'black'`` and
      ``'orig'`` (``'orig'`` is skipped when the file name contains
      ``'rgba'``) and seeds 0..19, writing
      ``<input_dir>/videos/<name>_<bg>_<seed>.mp4`` for each combination.
      The ``bg`` argument is ignored in this mode.
    * otherwise: generate a single video with the given ``seed`` and ``bg``,
      writing ``<input_dir>/<name>_generated.mp4`` plus one PNG per frame in
      ``<input_dir>/<name>_frames/``.

    Args:
        input_path: path of the input image.
        seed: integer seed, or ``None`` to run the sweep.
        bg: ``'white'``, ``'black'`` or ``'orig'``.
        is_pad: if true, work at 1024x576 with the image padded horizontally,
            and crop saved frames back to the central square; otherwise work
            at 512x512.

    Raises:
        ValueError: if ``bg == 'orig'`` is requested for a file whose name
            contains ``'rgba'`` (single-seed mode only).
    """
    # Same stem rule as before (text up to the first '.'), but basename copes
    # with the platform's path separators.
    name = os.path.basename(input_path).split('.')[0]
    # dirname is '' for a bare file name; os.path.join then yields a relative
    # path, whereas f"{input_dir}/..." would point at the filesystem root.
    input_dir = os.path.dirname(input_path)

    if seed is not None and bg == 'orig' and 'rgba' in name:
        # Checked up front so the error surfaces before the model is loaded.
        raise ValueError(f"bg='orig' is not supported for rgba inputs (got {input_path!r})")

    pipe = StableVideoDiffusionPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
    )
    # pipe.enable_model_cpu_offload()
    pipe.to("cuda")

    if is_pad:
        height, width = 576, 1024
    else:
        height, width = 512, 512

    if seed is None:
        video_dir = os.path.join(input_dir, 'videos')
        # The writer does not create missing directories.
        os.makedirs(video_dir, exist_ok=True)
        for bg_mode in ('white', 'black', 'orig'):
            if bg_mode == 'orig' and 'rgba' in name:
                continue
            image = _load_conditioning_image(input_path, width, height, bg_mode, is_pad)
            for sweep_seed in range(20):
                generator = torch.manual_seed(sweep_seed)
                frames = pipe(image, height, width, generator=generator).frames[0]
                video_path = os.path.join(video_dir, f"{name}_{bg_mode}_{sweep_seed:03}.mp4")
                imageio.mimwrite(video_path, frames, fps=7)
    else:
        image = _load_conditioning_image(input_path, width, height, bg, is_pad)
        generator = torch.manual_seed(seed)
        frames = pipe(image, height, width, generator=generator).frames[0]

        imageio.mimwrite(os.path.join(input_dir, f"{name}_generated.mp4"), frames, fps=7)
        frames_dir = os.path.join(input_dir, f"{name}_frames")
        os.makedirs(frames_dir, exist_ok=True)
        # Horizontal offset of the square region inside a padded frame.
        side_pad = (width - height) // 2
        for idx, img in enumerate(frames):
            if is_pad:
                img = img.crop((side_pad, 0, width - side_pad, height))
            img.save(os.path.join(frames_dir, f"{idx:03}.png"))

if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string -- including
        ``"False"`` and ``"0"`` -- into ``True``.
        """
        text = value.strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    # Command-line entry point: parse the options and run the generation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--bg", type=str, default='white')
    # Accepts "--is_pad True/False" as before, and a bare "--is_pad" as True.
    parser.add_argument("--is_pad", type=_str2bool, nargs='?', const=True, default=False)
    # Unknown options are collected in ``extras`` and ignored.
    args, extras = parser.parse_known_args()
    gen_vid(args.path, args.seed, args.bg, args.is_pad)