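"""Demo: spatially aware text-to-video generation with zeroscope_v2_576w.

For each prompt below, two mirrored bounding-box masks are used to place the
two named foreground objects on opposite sides of the frame. Videos,
per-frame PNGs, and the masks themselves are written under OUTPUT_PATH.
"""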
import os
import sys
import warnings

# Make the repository root importable so `models` resolves when the script is
# run from this directory.
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import torch
import torchvision
import torchvision.io as vision_io
from PIL import Image
from diffusers.utils import export_to_video

from models.pipelines import TextToVideoSDPipelineSpatialAware

warnings.filterwarnings("ignore")

OUTPUT_PATH = "/scr/demo"

def generate_video(pipe, overall_prompt, latents, get_latents=False, num_frames=24,
                   num_inference_steps=50, fg_masks=None, fg_masked_latents=None,
                   frozen_steps=0, frozen_prompt=None, custom_attention_mask=None,
                   fg_prompt=None):
    """Run the spatially aware pipeline: `fg_masks` supplies one latent-space
    mask per entry of `fg_prompt`, applied for the first `frozen_steps` steps."""
    video_frames = pipe(
        overall_prompt, num_frames=num_frames, latents=latents,
        num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
        frozen_steps=frozen_steps, latents_all_input=fg_masked_latents,
        frozen_prompt=frozen_prompt, custom_attention_mask=custom_attention_mask,
        fg_prompt=fg_prompt, make_attention_mask_2d=True,
        attention_mask_block_diagonal=True, height=320, width=576,
    ).frames
    if get_latents:
        # A second, unmasked pass that returns latents rather than decoded frames.
        video_latents = pipe(overall_prompt, num_frames=num_frames, latents=latents,
                             num_inference_steps=num_inference_steps,
                             output_type="latent").frames
        return video_frames, video_latents

    return video_frames
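
# Minimal usage sketch. `mask_a`/`mask_b` are hypothetical names for masks
# shaped like those built in __main__ below:
#   frames = generate_video(pipe, "A cat and a dog in a park", latents,
#                           fg_masks=[mask_a, mask_b], fg_prompt=["cat", "dog"],
#                           frozen_steps=2)
#   export_to_video(frames, "demo.mp4")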

def save_frames(path):
    """Dump every frame of `{path}.mp4` as a PNG into a directory named `path`."""
    video, audio, video_info = vision_io.read_video(f"{path}.mp4", pts_unit='sec')

    num_frames = video.size(0)

    # read_video returns uint8 frames as (T, H, W, C), so each frame can go
    # straight to PIL without a permute.
    os.makedirs(f"{path}", exist_ok=True)
    for i in range(num_frames):
        frame = video[i].numpy()
        img = Image.fromarray(frame.astype('uint8'))
        img.save(f"{path}/frame_{i:04d}.png")

if __name__ == "__main__":
    num_frames = 24
    save_path = "video"
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Latents are [batch, channels, frames, H/8, W/8] for the 320x576 output.
    random_latents = torch.randn([1, 4, num_frames, 40, 72],
                                 generator=torch.Generator().manual_seed(2)).to(torch_device)

    # Retry once: from_pretrained can fail transiently on the first download.
    try:
        pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
            "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
    except Exception:
        pipe = TextToVideoSDPipelineSpatialAware.from_pretrained(
            "cerspense/zeroscope_v2_576w", torch_dtype=torch.float, variant="fp32").to(torch_device)
        
    # Build one bounding-box mask per foreground object in latent space
    # ([frames, 1, 40, 72]); the second box mirrors the first across the
    # width axis so the two objects occupy opposite sides of the frame.
    bbox_mask = torch.zeros([num_frames, 1, 40, 72], device=torch_device)
    bbox_mask_2 = torch.zeros([num_frames, 1, 40, 72], device=torch_device)

    # Note the indexing below: x_* spans dim 2 (height, 40) and y_* spans
    # dim 3 (width, 72), despite the variable names.
    x_start = [10 + (i % 3) for i in range(num_frames)]  # slight per-frame jitter
    x_end = [30 + (i % 3) for i in range(num_frames)]
    y_start = [10 for _ in range(num_frames)]            # static extent on the width axis
    y_end = [25 for _ in range(num_frames)]

    # Set the masks to one inside each bounding box.
    for i in range(num_frames):
        bbox_mask[i, :, x_start[i]:x_end[i], y_start[i]:y_end[i]] = 1
        bbox_mask_2[i, :, x_start[i]:x_end[i], 72 - y_end[i]:72 - y_start[i]] = 1

    fg_masks = [bbox_mask, bbox_mask_2]
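
    # A sketch of an alternative, not used below: to sweep the boxes across the
    # clip instead of jittering in place, the width-axis bounds could be made
    # frame-dependent, e.g.:
    #   y_start = [int(10 + 30 * i / num_frames) for i in range(num_frames)]
    #   y_end = [s + 15 for s in y_start]
    # The mirrored second box then travels in the opposite direction.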

    frozen_prompt = None
    fg_masked_latents = None

    # (foreground object names, overall prompt) pairs; the two names in each
    # pair map onto the two masks above, in order.
    prompts = [
        (["cat", "goldfish bowl"], "A cat curiously staring at a goldfish bowl on a sunny windowsill."),
        (["Superman", "Batman"], "Superman and Batman standing side by side in a heroic pose against a city skyline."),
        (["rose", "daisy"], "A rose and a daisy in a small vase on a rustic wooden table."),
        (["Harry Potter", "Hermione Granger"], "Harry Potter and Hermione Granger studying a magical map."),
        (["butterfly", "dragonfly"], "A butterfly and a dragonfly resting on a leaf in a vibrant garden."),
        (["teddy bear", "toy train"], "A teddy bear and a toy train on a child's playmat in a brightly lit room."),
        (["frog", "turtle"], "A frog and a turtle sitting on a lily pad in a serene pond."),
        (["Mickey Mouse", "Donald Duck"], "Mickey Mouse and Donald Duck enjoying a day at the beach, building a sandcastle."),
        (["penguin", "seal"], "A penguin and a seal lounging on an iceberg in the Antarctic."),
        (["lion", "zebra"], "A lion and a zebra peacefully drinking water from the same pond in the savannah.")
    ]

    for fg_object, overall_prompt in prompts:
        # Save the per-frame masks next to the videos for inspection.
        mask_dir = f"{OUTPUT_PATH}/{save_path}/{overall_prompt}-mask"
        os.makedirs(mask_dir, exist_ok=True)
        for i in range(num_frames):
            torchvision.utils.save_image(fg_masks[0][i, 0], f"{mask_dir}/frame_{i:04d}_0.png")
            torchvision.utils.save_image(fg_masks[1][i, 0], f"{mask_dir}/frame_{i:04d}_1.png")

        print(fg_object, overall_prompt)
        seed = 2
        random_latents = torch.randn([1, 4, num_frames, 40, 72],
                                     generator=torch.Generator().manual_seed(seed)).to(torch_device)
        # range(40, 50, 10) currently yields only 40; widen it to sweep more step counts.
        for num_inference_steps in range(40, 50, 10):
            for frozen_steps in [0, 1, 2]:
                video_frames = generate_video(
                    pipe, overall_prompt, random_latents, get_latents=False,
                    num_frames=num_frames, num_inference_steps=num_inference_steps,
                    fg_masks=fg_masks, fg_masked_latents=fg_masked_latents,
                    frozen_steps=frozen_steps, frozen_prompt=frozen_prompt,
                    fg_prompt=fg_object)
                # Save the video and its individual frames.
                out_dir = f"{OUTPUT_PATH}/{save_path}/{overall_prompt}"
                os.makedirs(out_dir, exist_ok=True)
                out_stem = f"{out_dir}/{frozen_steps}_of_{num_inference_steps}_{seed}_masked"
                export_to_video(video_frames, f"{out_stem}.mp4")
                save_frames(out_stem)